diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f288702d2fa16d3cdf0035b15a9fcbc552cd88e7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. 
+States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. 
+ + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. 
+ + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. 
+ + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. 
+ + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. 
+ + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. 
+ + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/README.md b/README.md index 7be5fc7f47d5db027d120b8024982df93db95b74..9cca52fcf6655002da130c7be9592e0faabc610a 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,145 @@ ---- -license: mit ---- +# TransfoRNA +TransfoRNA is a **bioinformatics** and **machine learning** tool based on **Transformers** to provide annotations for 11 major classes (miRNA, rRNA, tRNA, snoRNA, protein +-coding/mRNA, lncRNA, YRNA, piRNA, snRNA, snoRNA and vtRNA) and 1923 sub-classes +for **human small RNAs and RNA fragments**. These are typically detected by RNA-seq NGS (next generation sequencing) data. + +TransfoRNA can be trained on just the RNA sequences and optionally on additional information such as secondary structure. The result is a major and sub-class assignment combined with a novelty score (Normalized Levenshtein Distance) that quantifies the difference between the query sequence and the closest match found in the training set. Based on that it decides if the query sequence is novel or familiar. 
TransfoRNA uses a small curated set of ground truth labels obtained from common knowledge-based bioinformatics tools that map the sequences to transcriptome databases and a reference genome. Using TransfoRNA's framework, the high-confidence annotations in the TCGA dataset can be increased threefold.
+
+
+## Dataset (Objective):
+- **The Cancer Genome Atlas, [TCGA](https://www.cancer.gov/about-nci/organization/ccg/research/structural-genomics/tcga)** offers sequencing data of small RNAs and is used to evaluate TransfoRNA's classification performance.
+  - Sequences are annotated based on a knowledge-based annotation approach that provides annotations for ~2k different sub-classes belonging to 11 major classes.
+  - Knowledge-based annotations are divided into three sets of varying confidence levels: a **high-confidence (HICO)** set, a **low-confidence (LOCO)** set, and a **non-annotated (NA)** set for sequences that could not be annotated at all. Only HICO annotations are used for training.
+  - HICO RNAs cover ~2k sub-classes and constitute 19.6% of all RNAs found in TCGA. The LOCO and NA sets comprise 66.9% and 13.6% of RNAs, respectively.
+  - HICO RNAs are further divided into **in-distribution (ID)** (374 sub-classes) and **out-of-distribution (OOD)** (1549 sub-classes) sets.
+    - Criterion for ID vs. OOD: sub-classes containing more than 8 sequences are considered ID, otherwise OOD.
+  - An additional **putative 5' adapter affixes set** contains 294 sequences known to be technical artefacts: their 5' end perfectly matches the last five or more nucleotides of the 5' adapter sequence commonly used in small RNA sequencing.
+  - The knowledge-based annotation (KBA) pipeline, including an installation guide, is located under `kba_pipline`.
+
+## Models
+There are 5 classifier models currently available, each with a different input representation.
+ - Baseline:
+   - Input: (single input) Sequence
+   - Model: An embedding layer that converts sequences into vectors, followed by a classification feed-forward layer.
+ - Seq:
+   - Input: (single input) Sequence
+   - Model: A transformer-based encoder model.
+ - Seq-Seq:
+   - Input: (dual inputs) Sequence divided into even and odd tokens.
+   - Model: One transformer encoder for the odd tokens and another for the even tokens.
+ - Seq-Struct:
+   - Input: (dual inputs) Sequence + secondary structure
+   - Model: A transformer encoder for the sequence and another for the secondary structure.
+ - Seq-Rev (best performing):
+   - Input: (dual inputs) Sequence
+   - Model: A transformer encoder for the sequence and another for the reversed sequence.
+
+*Note: These transformer-based models show both overlapping and distinct capabilities. Consequently, an ensemble model is created to leverage those capabilities.*
+
+## Data Availability
+The data and the models can be downloaded from [here](https://www.dropbox.com/sh/y7u8cofmg41qs0y/AADvj5lw91bx7fcDxghMbMtsa?dl=0).
+
+This will download three subfolders that should be kept on the same folder level as `src`:
+ - `data`: Contains three files:
+   - `TCGA` anndata with ~75k sequences and `var` columns containing the knowledge-based annotations.
+   - `HBDXBase.csv` containing a list of RNA precursors which are then used for data augmentation.
+   - `subclass_to_annotation.json` holding the mapping from every sub-class to its major class.
+
+ - `models`:
+   - `benchmark`: contains benchmark models trained on sncRNA and premiRNA data
(see the additional datasets at the bottom of this README).
+   - `tcga`: All models trained on the TCGA data; `TransfoRNA_ID` (for testing and validation) and `TransfoRNA_FULL` (the production version, with higher RNA major- and sub-class coverage). Each of the two folders contains all the models, trained separately on major-class and sub-class targets.
+ - `kba_pipeline`: contains the mapping reference data required to run the knowledge-based pipeline manually.
+
+## Repo Structure
+- `conf`: Contains the configurations of each model, the training and the inference settings.
+
+  The `conf/main_config.yaml` file offers options to change the task, the training settings and the logging; all options and their permitted values are listed as comments in that file.
+
+- `transforna`: Contains two folders:
+  - `src`: contains the transforna package. View transforna's architecture [here](https://github.com/gitHBDX/TransfoRNA/blob/master/transforna/src/readme.md).
+  - `bin`: contains all scripts necessary for reproducing the manuscript figures.
+
+## Installation
+
+ The `install.sh` script creates a `transforna` conda environment in which all the packages required by TransfoRNA are installed. Simply navigate to the root directory and run from the terminal:
+
+ ```
+ #make install script executable
+ chmod +x install.sh
+
+ #run script
+ ./install.sh
+ ```
+
+## TransfoRNA API
+ In `transforna/src/inference/inference_api.py`, all the functionalities of transforna are offered as APIs. There are two functions of interest:
+ - `predict_transforna`: Computes, for a set of sequences and a given model, one of several outputs: the embeddings, logits, explanatory (similar) sequences, attention masks or UMAP coordinates.
+ - `predict_transforna_all_models`: Same as `predict_transforna`, but computes the desired output for all the models and also aggregates the output of the ensemble model.
+
+ Both return a pandas dataframe containing the sequences along with the desired computation.
+
+ Check the script at `src/test_inference_api.py` for a basic demo on how to call either of the APIs.
+
+## Inference from terminal
+For inference, two paths in `conf/inference_settings/default.yaml` have to be edited:
+ - `sequences_path`: The full path to a csv file containing the sequences for which annotations are to be inferred.
+ - `model_path`: The full path of the model (currently this points to the Seq model).
+
+Also, in `main_config.yaml`, make sure to edit `model_name` to match the input expected by the loaded model:
+ - `model_name`: the name of the model. One of `"seq"`, `"seq-seq"`, `"seq-struct"`, `"baseline"` or `"seq-rev"` (see above).
+
+Then, navigate to the repository's root directory and run the following command:
+
+```
+python transforna/__main__.py inference=True
+```
+
+After inference, an `inference_output` folder will be created under `outputs/`, containing two files:
+ - `(model_name)_embedds.csv`: contains a vector embedding per sequence in the inference set (can be used for downstream tasks).
+   *Note: The embeddings of each sequence are only logged if `log_embedds` in the `main_config` is `True`.*
+ - `(model_name)_inference_results.csv`: contains a `Net-Label` column with the predicted label and an `Is Familiar?` boolean column with the model's novelty-predictor output
(`True`: familiar / `False`: novel).
+   *Note: The output will also contain the logits of the model if `log_logits` in the `main_config` is `True`.*
+
+
+## Train on custom data
+TransfoRNA can be trained using input data provided as AnnData, csv or fasta. If the input is AnnData, then `anndata.var` should contain all the sequences. A few changes have to be made (follow `conf/train_model_configs/tcga.py` as a template):
+
+In `conf/train_model_configs/custom.py`:
+- `dataset_path_train` has to point to the input data, which should contain the following columns (see the sketch at the end of this section):
+  - a `sequence` column,
+  - a `small_RNA_class_annotation` column indicating the major class if available (otherwise NaN),
+  - a `five_prime_adapter_filter` column specifying whether the sequence is considered a real sequence or an artifact (`True` for real, `False` for artifact),
+  - a `subclass_name` column containing the sub-class name if available (otherwise NaN), and
+  - a boolean `hico` column indicating whether a sequence is high confidence or not.
+- If sampling from the precursor is required in order to augment the sub-classes, `precursor_file_path` should point to a file of precursors. Follow the scheme of `HBDxBase.csv` and have a look at the `PrecursorAugmenter` class in `transforna/src/processing/augmentation.py`.
+- `mapping_dict_path` should contain the mapping from sub-class to major class, e.g. 'miR-141-5p' to 'miRNA'.
+- `clf_target` sets the classification target of the model and should be either `sub_class_hico` for training on targets in `subclass_name`, or `major_class_hico` for training on targets in `small_RNA_class_annotation`. For both, only high-confidence sequences are selected for training (based on the `hico` column).
+
+In `conf/main_config.yaml`, some changes should be made:
+- change `task` to `custom`, or to whatever name `custom.py` has been renamed to.
+- set the `model_name` as desired.
+
+To train TransfoRNA from the root directory:
+```
+python transforna/__main__.py
+```
+Using [Hydra](https://hydra.cc/), any option in the main config can be changed. For instance, to train a `Seq-Struct` TransfoRNA model without using a validation split:
+```
+python transforna/__main__.py train_split=False model_name='seq-struct'
+```
+After training, an output folder is automatically created in the root directory, where training is logged.
+The structure of the output folder is chosen by Hydra to be `<day>/<time>/`, under which a set of results folders is created during training:
+- `ckpt`: contains the latest checkpoint of the model.
+- `embedds`:
+  - Contains one file per split (train/valid/test/ood/na).
+  - Each file is a `csv` containing the sequences plus their embeddings (numeric representations of each RNA sequence produced by the model) as well as the logits. The logits are the values the model produces for each sequence, reflecting its confidence that the sequence belongs to a certain class.
+- `meta`: A folder containing a `yaml` file with all the hyperparameters used for the current run.
+- `analysis`: contains the learned novelty threshold separating the in-distribution set (familiar) from the out-of-distribution set (novel).
+- `figures`: contains figures showing the Normalized Levenshtein Distance (NLD) distribution per split.
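+
+As referenced above, the following is a minimal sketch of what a custom training input file could look like if provided as a csv. The column names follow the list in this section; the sequences and labels are illustrative placeholders only, not real annotations:
+
+```
+import pandas as pd
+
+# Toy example of a custom training table; all values are placeholders.
+df = pd.DataFrame({
+    "sequence": ["ACGTACGTACGTACGTACGT", "TGGTAGTAGGTTGTATAGTT"],  # placeholder sequences
+    "small_RNA_class_annotation": ["miRNA", None],                 # major class if available, otherwise empty/NaN
+    "five_prime_adapter_filter": [True, True],                     # True: real sequence, False: artifact
+    "subclass_name": ["miR-141-5p", None],                         # sub-class if available, otherwise empty/NaN
+    "hico": [True, False],                                         # high-confidence flag; only hico rows are used for training
+})
+df.to_csv("custom_training_set.csv", index=False)
+```
+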
+ + +## Additional Datasets (Objective): +- sncRNA, collected from [RFam](https://rfam.org/) (classification of RNA precursors into 13 classes) +- premiRNA [human miRNAs](http://www.mirbase.org)(classification of true vs pseudo precursors) + diff --git a/conf/__init__.py b/conf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/conf/hydra/job_logging/custom.yaml b/conf/hydra/job_logging/custom.yaml new file mode 100644 index 0000000000000000000000000000000000000000..60012b7a807f1791cc7982bf896c7e04674d5c76 --- /dev/null +++ b/conf/hydra/job_logging/custom.yaml @@ -0,0 +1,13 @@ +version: 1 +formatters: + simple: + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' +handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout +root: + handlers: [console] + +disable_existing_loggers: false \ No newline at end of file diff --git a/conf/inference_settings/default.yaml b/conf/inference_settings/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..227acec283d633cdc3e1f86345d9bd6dc9ae0f2e --- /dev/null +++ b/conf/inference_settings/default.yaml @@ -0,0 +1,3 @@ +infere_original_testset: false +model_path: /nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/Seq/ckpt/model_params_tcga.pt +sequences_path: /nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/data/inference_set.csv diff --git a/conf/main_config.yaml b/conf/main_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c49311a625775f8011059784996ea4517a37b084 --- /dev/null +++ b/conf/main_config.yaml @@ -0,0 +1,51 @@ +defaults: + - model: transforna + - inference_settings: default + - override hydra/job_logging: disabled + +task: tcga # tcga,sncrna or premirna or custom (for a custom dataset) + + +train_config: + _target_: train_model_configs.${task}.GeneEmbeddTrainConfig + +model_config: + _target_: train_model_configs.${task}.GeneEmbeddModelConfig + +#train settings +model_name: seq #seq, seq-seq, seq-struct, seq-reverse, or baseline +trained_on: full #full(production, more coverage) or id (for test/eval purposes) +path_to_models: /nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/ #edit path to point to models/tcga/ directory: will be used if trained_on is full +inference: False # Should TransfoRNA be used for inference or train? True or False +#if inference is true, should the logits be logged? 
+log_logits: False + + +train_split: True # True or False +valid_size: 0.15 # 0 < valid_size < 1 + +#CV +cross_val: True # True or False +num_replicates: 1 # Integer, num_replicates for cross-validation + +#seed +seed: 1 # Integer +device_number: 1 # Integer, select GPU + + +#logging sequence embeddings + metrics to tensorboard +log_embedds: True # True or False +tensorboard: False # True or False + +#disable hydra output +hydra: + run: + dir: ./outputs/${now:%Y-%m-%d}/${model_name} + searchpath: + - file:///nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/conf +# output_subdir: null #uncomment to disable hydra output + + + + + diff --git a/conf/model/transforna.yaml b/conf/model/transforna.yaml new file mode 100644 index 0000000000000000000000000000000000000000..148f733ab7121d73e72810f97582551546a4a7fb --- /dev/null +++ b/conf/model/transforna.yaml @@ -0,0 +1,9 @@ +skorch_model: + _target_: transforna.Net + module: transforna.GeneEmbeddModel + criterion: transforna.LossFunction + max_epochs: 0 #infered from task specific train config + optimizer: torch.optim.AdamW + device: cuda + batch_size: 64 + iterator_train__shuffle: True \ No newline at end of file diff --git a/conf/readme.md b/conf/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..486fa180ae4072f4b862541111fc2ca9c0f67d7d --- /dev/null +++ b/conf/readme.md @@ -0,0 +1,17 @@ +The `inference_settings` has a default yaml containing four keys: + -`sequences_path`: The full path of the file containing the sequences for which their annotations are to be infered. + - `model_path`: the full path of the model to be used for inference. + - `model_name`: A model name indicating the inputs the model expects. One of `seq`,`seq-seq`,`seq-struct`,`seq-reverse` or `baseline` + - `infere_original_testset`: True/False indicating whether inference should be computed on the original test set. + +`model` contains the skeleton of the model used, the optimizer, loss function and device. All models are built using [skorch](https://skorch.readthedocs.io/en/latest/?badge=latest) + +`train_model_configs` contain the hyperparameters for each dataset; tcga, sncrna and premirna: + + - Each file contains the model and the train config. + + - Model config: contains the model hyperparameters, sequence tokenization scheme and allows for choosing the model. + + - Train config: contains training settings such as the learning rate hyper parameters as well as `dataset_path_train`. + - `dataset_path_train`: should point to the dataset [(Anndata)](https://anndata.readthedocs.io/en/latest/) used for training. 
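+
+As an illustration of the `inference_settings` keys described above, such a yaml could look as follows; the paths are placeholders and need to be adapted to where the downloaded models and your sequence file live (compare `inference_settings/default.yaml`):
+
+```
+infere_original_testset: false
+model_path: /path/to/models/tcga/TransfoRNA_FULL/sub_class/Seq/ckpt/model_params_tcga.pt
+sequences_path: /path/to/sequences_to_annotate.csv
+```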
+ diff --git a/conf/train_model_configs/__init__.py b/conf/train_model_configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/conf/train_model_configs/custom.py b/conf/train_model_configs/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0cfbd5bcdcf783b062b979b7bef734f85d7ca5 --- /dev/null +++ b/conf/train_model_configs/custom.py @@ -0,0 +1,76 @@ +import math +import os +from dataclasses import dataclass, field +from typing import Dict, List + +dirname, _ = os.path.split(os.path.dirname(__file__)) + + +@dataclass +class GeneEmbeddModelConfig: + + model_input: str = "" #will be infered + + num_embed_hidden: int = 100 #30 for exp, 100 for rest + ff_input_dim:int = 0 #is infered later, equals gene expression len + ff_hidden_dim: List = field(default_factory=lambda: [300]) #300 for exp hico + feed_forward1_hidden: int = 256 + num_attention_project: int = 64 + num_encoder_layers: int = 1 + dropout: float = 0.2 + n: int = 121 + relative_attns: List = field(default_factory=lambda: [29, 4, 6, 8, 10, 11]) + num_attention_heads: int = 5 + + window: int = 2 + tokens_len: int = math.ceil(max_length / window) + second_input_token_len: int = 0 # is infered during runtime + vocab_size: int = 0 # is infered during runtime + second_input_vocab_size: int = 0 # is infered during runtime + tokenizer: str = ( + "overlap" # either overlap or no_overlap or overlap_multi_window + ) + + clf_target:str = 'm' # sub_class_hico or major_class_hico. hico = high confidence + num_classes: int = 0 #will be infered during runtime + class_mappings:List = field(default_factory=lambda: [])#will be infered during runtime + class_weights :List = field(default_factory=lambda: []) + # how many extra window sizes other than deafault window + temperatures: List = field(default_factory=lambda: [0,10]) + + tokens_mapping_dict: Dict = None + false_input_perc:float = 0.0 + + +@dataclass +class GeneEmbeddTrainConfig: + dataset_path_train: str = 'path/to/anndata.h5ad' + precursor_file_path: str = 'path/to/precursor_file.csv' #if not provided, sampling from the precurosr will not be done + mapping_dict_path: str = 'path/to/mapping_dict.json' #required for mapping sub class to major class, i.e: mir-568-3p to miRNA + device: str = "cuda" + l2_weight_decay: float = 0.05 + batch_size: int = 512 + + batch_per_epoch:int = 0 # is infered during runtime + + label_smoothing_sim:float = 0.2 + label_smoothing_clf:float = 0.0 + # learning rate + learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to' + lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section + lr_warmup_end: float = 1 # end of linear warmup section , annealing begin + # TODO: 122 is the number of train batches per epoch, should be infered and set + # warmup batch should be during the form epoch*(train batch per epoch) + warmup_epoch: int = 10 # how many batches linear warm up for + final_epoch: int = 20 # final batch of training when want learning rate + + top_k: int = 10#int(0.1 * batch_size) # if the corresponding rna/GE appears during the top k, the correctly classified + cross_val: bool = False + labels_mapping_path: str = None + filter_seq_length:bool = False + + num_augment_exp:int = 20 + shuffle_exp: bool = False + + max_epochs: int = 3000 + diff --git a/conf/train_model_configs/premirna.py b/conf/train_model_configs/premirna.py new file mode 100644 index 
0000000000000000000000000000000000000000..986be8a87e62c661b0853b0119279d4a496c42eb --- /dev/null +++ b/conf/train_model_configs/premirna.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass, field +from typing import List + + +@dataclass +class GeneEmbeddModelConfig: + # input dim for the embedding and positional encoders + # as well as all k,q,v input/output dims + model_input: str = "seq-struct" + num_embed_hidden: int = 256 + ff_hidden_dim: List = field(default_factory=lambda: [1200, 800]) + feed_forward1_hidden: int = 1024 + num_attention_project: int = 64 + num_encoder_layers: int = 1 + dropout: float = 0.5 + n: int = 121 + relative_attns: List = field(default_factory=lambda: [int(112), int(112), 6*3, 8*3, 10*3, 11*3]) + num_attention_heads: int = 1 + + window: int = 2 + tokens_len: int = 0 #will be infered later + second_input_token_len: int = 0 # is infered in runtime + vocab_size: int = 0 # is infered in runtime + second_input_vocab_size: int = 0 # is infered in runtime + tokenizer: str = ( + "overlap" # either overlap or no_overlap or overlap_multi_window + ) + num_classes: int = 0 #will be infered in runtime + class_weights :List = field(default_factory=lambda: []) + tokens_mapping_dict: dict = None + + #false input percentage + false_input_perc:float = 0.1 + model_input: str = "seq-struct" + +@dataclass +class GeneEmbeddTrainConfig: + dataset_path_train: str = "/data/hbdx_ldap_local/analysis/data/premirna/train" + dataset_path_test: str = "/data/hbdx_ldap_local/analysis/data/premirna/test" + datset_path_additional_testset: str = "/data/hbdx_ldap_local/analysis/data/premirna/" + labels_mapping_path:str = "/data/hbdx_ldap_local/analysis/data/premirna/labels_mapping_dict.pkl" + device: str = "cuda" + l2_weight_decay: float = 1e-5 + batch_size: int = 64 + + batch_per_epoch: int = 0 #will be infered later + label_smoothing_sim:float = 0.0 + label_smoothing_clf:float = 0.0 + + # learning rate + learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to' + lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section + lr_warmup_end: float = 1 # end of linear warmup section , annealing begin + # TODO: 122 is the number of train batches per epoch, should be infered and set + # warmup batch should be in the form epoch*(train batch per epoch) + warmup_epoch: int = 10 # how many batches linear warm up for + final_epoch: int = 20 # final batch of training when want learning rate + + top_k: int = int( + 0.05 * batch_size + ) # if the corresponding rna/GE appears in the top k, the correctly classified + label_smoothing: float = 0.0 + cross_val: bool = False + filter_seq_length:bool = True + train_epoch: int = 3000 + max_epochs: int = 3500 + + diff --git a/conf/train_model_configs/sncrna.py b/conf/train_model_configs/sncrna.py new file mode 100644 index 0000000000000000000000000000000000000000..42b8bb93dbedc0a691e68509dbbca3365a2ea3a9 --- /dev/null +++ b/conf/train_model_configs/sncrna.py @@ -0,0 +1,70 @@ +import os +from dataclasses import dataclass, field +from pickletools import int4 +from typing import List + + +@dataclass +class GeneEmbeddModelConfig: + # input dim for the embedding and positional encoders + # as well as all k,q,v input/output dims + model_input: str = "seq-struct" + num_embed_hidden: int = 256 + ff_hidden_dim: List = field(default_factory=lambda: [1200, 800]) + feed_forward1_hidden: int = 1024 + num_attention_project: int = 64 + num_encoder_layers: int = 2 + dropout: float = 0.3 + n: int = 121 + window:int = 4 + relative_attns: List = 
field(default_factory=lambda: [int(360), int(360)]) + num_attention_heads: int = 4 + + tokens_len: int = 0 #will be infered later + second_input_token_len:int = 0 # is infered in runtime + vocab_size: int = 0 # is infered in runtime + second_input_vocab_size: int = 0 # is infered in runtime + tokenizer: str = ( + "overlap" # either overlap or no_overlap or overlap_multi_window + ) + # how many extra window sizes other than deafault window + num_classes: int = 0 #will be infered in runtime + class_weights :List = field(default_factory=lambda: []) + tokens_mapping_dict: dict = None + + #false input percentage + false_input_perc:float = 0.2 + + model_input: str = "seq-struct" + + +@dataclass +class GeneEmbeddTrainConfig: + dataset_path_train: str = "/data/hbdx_ldap_local/analysis/data/sncRNA/train.h5ad" + dataset_path_test: str = "/data/hbdx_ldap_local/analysis/data/sncRNA/test.h5ad" + labels_mapping_path:str = "/data/hbdx_ldap_local/analysis/data/sncRNA/labels_mapping_dict.pkl" + device: str = "cuda" + l2_weight_decay: float = 1e-5 + batch_size: int = 64 + + batch_per_epoch:int = 0 #will be infered later + label_smoothing_sim:float = 0.0 + label_smoothing_clf:float = 0.0 + + # learning rate + learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to' + lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section + lr_warmup_end: float = 1 # end of linear warmup section , annealing begin + # TODO: 122 is the number of train batches per epoch, should be infered and set + # warmup batch should be in the form epoch*(train batch per epoch) + warmup_epoch: int = 10 # how many batches linear warm up for + final_epoch: int = 20 # final batch of training when want learning rate + + top_k: int = int( + 0.05 * batch_size + ) # if the corresponding rna/GE appears in the top k, the correctly classified + label_smoothing: float = 0.0 + cross_val: bool = False + filter_seq_length:bool = True + train_epoch: int = 800 + max_epochs:int = 1000 \ No newline at end of file diff --git a/conf/train_model_configs/tcga.py b/conf/train_model_configs/tcga.py new file mode 100644 index 0000000000000000000000000000000000000000..6ee8fe237eb60e1a2021b4c3a6b2a575ccecc73f --- /dev/null +++ b/conf/train_model_configs/tcga.py @@ -0,0 +1,81 @@ +import math +import os +from dataclasses import dataclass, field +from typing import Dict, List + +dirname, _ = os.path.split(os.path.dirname(__file__)) + + +@dataclass +class GeneEmbeddModelConfig: + + model_input: str = "" #will be infered + + num_embed_hidden: int = 100 #30 for exp, 100 for rest + ff_input_dim:int = 0 #is infered later, equals gene expression len + ff_hidden_dim: List = field(default_factory=lambda: [300]) #300 for exp hico + feed_forward1_hidden: int = 256 + num_attention_project: int = 64 + num_encoder_layers: int = 1 + dropout: float = 0.2 + n: int = 121 + relative_attns: List = field(default_factory=lambda: [29, 4, 6, 8, 10, 11]) + num_attention_heads: int = 5 + + window: int = 2 + # 200 is max rna length. 
+ # TODO: if tokenizer is overlap, then max_length should be 60 + # otherwise, will get cuda error, maybe dask can help + max_length: int = 40 + tokens_len: int = math.ceil(max_length / window) + second_input_token_len: int = 0 # is infered during runtime + vocab_size: int = 0 # is infered during runtime + second_input_vocab_size: int = 0 # is infered during runtime + tokenizer: str = ( + "overlap" # either overlap or no_overlap or overlap_multi_window + ) + + clf_target:str = 'sub_class_hico' # sub_class, major_class, sub_class_hico or major_class_hico. hico = high confidence + num_classes: int = 0 #will be infered during runtime + class_mappings:List = field(default_factory=lambda: [])#will be infered during runtime + class_weights :List = field(default_factory=lambda: []) + # how many extra window sizes other than deafault window + temperatures: List = field(default_factory=lambda: [0,10]) + + tokens_mapping_dict: Dict = None + false_input_perc:float = 0.0 + + +@dataclass +class GeneEmbeddTrainConfig: + dataset_path_train: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv' + precursor_file_path: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/HBDxBase.csv' + mapping_dict_path: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json' + device: str = "cuda" + l2_weight_decay: float = 0.05 + batch_size: int = 512 + + batch_per_epoch:int = 0 # is infered during runtime + + label_smoothing_sim:float = 0.2 + label_smoothing_clf:float = 0.0 + # learning rate + learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to' + lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section + lr_warmup_end: float = 1 # end of linear warmup section , annealing begin + # TODO: 122 is the number of train batches per epoch, should be infered and set + # warmup batch should be during the form epoch*(train batch per epoch) + warmup_epoch: int = 10 # how many batches linear warm up for + final_epoch: int = 20 # final batch of training when want learning rate + + top_k: int = 10#int(0.1 * batch_size) # if the corresponding rna/GE appears during the top k, the correctly classified + cross_val: bool = False + labels_mapping_path: str = None + filter_seq_length:bool = False + + num_augment_exp:int = 20 + shuffle_exp: bool = False + + max_epochs: int = 3000 + + diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..3559ca664abf5a8cd92917ac27fce70dd2b18e89 --- /dev/null +++ b/environment.yml @@ -0,0 +1,24 @@ +name: transforna +channels: + - pytorch + - bioconda + - conda-forge +dependencies: + - anndata==0.8.0 + - dill==0.3.6 + - hydra-core==1.3.0 + - imbalanced-learn==0.9.1 + - matplotlib==3.5.3 + - numpy==1.22.3 + - omegaconf==2.2.2 + - pandas==1.5.2 + - plotly==5.10.0 + - PyYAML==6.0 + - rich==12.6.0 + - viennarna=2.5.0=py39h98c8e45_0 + - scanpy==1.9.1 + - scikit-learn==1.2.0 + - skorch==0.12.1 + - pytorch=1.10.1=py3.9_cuda11.3_cudnn8.2.0_0 + - tensorboard==2.11.2 + - Levenshtein==0.21.0 \ No newline at end of file diff --git a/install.sh b/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..9f274f935f667d299a386fda503a6eb564a10a9b --- /dev/null +++ b/install.sh @@ -0,0 +1,33 @@ +#!/bin/bash +python_version="3.9" +# Initialize conda +eval "$(conda shell.bash hook)" +#print the current environment +echo "The current environment is $CONDA_DEFAULT_ENV." 
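+# deactivate any (possibly nested) active conda environments until the base environment is reached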
+while [[ "$CONDA_DEFAULT_ENV" != "base" ]]; do
+    conda deactivate
+done
+
+# if no conda environment called transforna is found in the list of environments, create it
+if [[ $(conda env list | grep "transforna") == "" ]]; then
+    conda create -n transforna python=$python_version -y
+    conda activate transforna
+    conda install -c anaconda setuptools -y
+
+
+
+fi
+conda activate transforna
+
+echo "The current environment is transforna."
+pip install setuptools==59.5.0
+# Uninstall TransfoRNA using pip
+pip uninstall -y TransfoRNA
+
+rm -rf dist TransfoRNA.egg-info
+
+
+# Reinstall TransfoRNA using pip
+python setup.py sdist
+pip install dist/TransfoRNA-0.0.1.tar.gz
+rm -rf TransfoRNA.egg-info dist
\ No newline at end of file
diff --git a/kba_pipeline/README.md b/kba_pipeline/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..307a7ddb4d78f2e1a0f18c77af606472210e3964
--- /dev/null
+++ b/kba_pipeline/README.md
@@ -0,0 +1,77 @@
+# The HBDx knowledge-based annotation (KBA) pipeline for small RNA sequences
+
+Most small RNA annotation tools map the sequences sequentially against the reference databases of the individual small RNA classes. This imposes a priority order on the small RNA classes and conceals potential assignment ambiguities. The annotation strategy used here maps the sequences to the reference sequences of all small RNA classes at the same time, starting with zero mismatch tolerance. Unmapped sequences are then re-mapped with iteratively increasing mismatch tolerance of up to three mismatches. To reduce ambiguity, sequences are first mapped to the standard coding and non-coding genes with increasing mismatch tolerance; only then are the still unassigned sequences mapped to pseudogenes in the same manner. Additionally, all small RNA sequences are checked for potential bacterial or viral origin, for genomic overlap with human transposable element loci, and for whether they contain potential prefixes of the 5' adapter.
+
+In cases of multiple assignments per sequence (i.e. multiple precursors could be the origin of the sequence), the ambiguous annotation is resolved if
+  a) all assigned precursors overlap with the genomic region of the precursor with the shortest genomic context -> the subclass name of the precursor with the shortest genomic context is used, OR if
+  b) a bin of one of the assigned subclass names is at the 5' or 3' end of the respective precursor -> the subclass name matching the precursor end is used.
+In cases where the subclass names of a) and b) are not identical, the subclass name of method a) is assigned.
+
+
+![kba_pipeline_scheme_v05](https://github.com/gitHBDX/TransfoRNA/assets/79092907/62bf9e36-c7c7-4ff5-b747-c2c651281b42)
+
+
+a) Schematic overview of the knowledge-based annotation (KBA) strategy applied for TransfoRNA.
+
+b) Schematic overview of the miRNA annotation of the custom annotation (isomiR definition based on recent miRNA research [1]).
+
+c) Schematic overview of the tRNA annotation of the custom annotation (inspired by UNITAS sub-classification [2]).
+
+d) Binning strategy used in the custom annotation for the remaining RNA major classes. The number of nucleotides per bin is constant for each precursor sequence and ranges between 20 and 39 nucleotides. Assignments are based on the bin with the highest overlap to the sequence of interest (a minimal sketch of this binning rule is shown below).
+
+e) Filtering steps that were applied to obtain the set of HICO annotations that were used for training of the TransfoRNA models.
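+
+For illustration, the binning rule described in d) can be sketched in a few lines of Python (simplified sketch; the function name is illustrative, the actual implementation lives in `src/precursor_bins.py`):
+
+```python
+def split_into_bins(precursor_len: int, min_bin: int = 20, max_bin: int = 30) -> list:
+    """Split a precursor into bins of constant size, shrinking the bin size from 30 down
+    to 20 nt until the last bin is long enough; otherwise the too-short last bin is merged
+    into the previous one, which yields bin sizes between 20 and 39 nt."""
+    if precursor_len < min_bin:
+        return [precursor_len]
+    for bin_size in range(max_bin, min_bin - 1, -1):
+        bins = [min(bin_size, precursor_len - start) for start in range(0, precursor_len, bin_size)]
+        if bins[-1] >= min_bin:
+            return bins
+    bins[-2] += bins[-1]  # merge the too-short last bin into the previous one
+    return bins[:-1]
+
+print(split_into_bins(73))  # [26, 26, 21]; a sequence is then assigned to the bin it overlaps most
+```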
+ + +## Install environment + +```bash +cd kba_pipeline +conda env create --file environment.yml +``` + +## Run annotation pipeline + +Prerequisites: +- [ ] the sequences to be annotated need to be stored as fasta format in the `kba_pipeline/data` folder +- [ ] the reference files for mapping need to be stored in the `kba_pipeline/references` folder (the required subfolders `HBDxBase`, `hg38` and `bacterial_viral` can be downloaded together with the TransfoRNA models from https://www.dropbox.com/sh/y7u8cofmg41qs0y/AADvj5lw91bx7fcDxghMbMtsa?dl=0) + +```bash +conda activate hbdx_kba +cd src +python make_anno.py --fasta_file your_sequences_to_be_annotated.fa +``` + +This script calls two major functions: +- map_2_HBDxBase: sequential mismatch mapping to HBDxBase and genome +- annotate_from_mapping: generate sequence annotation based on mapping outputs + +The main annotation file `sRNA_anno_aggregated_on_seq.csv` will be generated in the folder `outputs` + + + +## References + +[1] Tomasello, Luisa, Rosario Distefano, Giovanni Nigita, and Carlo M. Croce. 2021. “The MicroRNA Family Gets Wider: The IsomiRs Classification and Role.” Frontiers in Cell and Developmental Biology 9 (June): 1–15. https://doi.org/10.3389/fcell.2021.668648. + +[2] Gebert, Daniel, Charlotte Hewel, and David Rosenkranz. 2017. “Unitas: The Universal Tool for Annotation of Small RNAs.” BMC Genomics 18 (1): 1–14. https://doi.org/10.1186/s12864-017-4031-9. + + diff --git a/kba_pipeline/environment.yml b/kba_pipeline/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..42627f90e1cfc2839d491178c8c772d0ef386139 --- /dev/null +++ b/kba_pipeline/environment.yml @@ -0,0 +1,20 @@ +name: hbdx_kba +channels: + - bioconda + - conda-forge +dependencies: + - anndata=0.8.0 + - bedtools=2.30.0 + - biopython=1.79 + - bowtie=1.3.1 + - joblib>=1.2.0 + - pyfastx=0.8.4 + - pytest + - python=3.10.6 + - pyyaml + - rich + - samtools=1.16.1 + - tqdm + - viennarna=2.5.1 + - levenshtein + - pip \ No newline at end of file diff --git a/kba_pipeline/src/annotate_from_mapping.py b/kba_pipeline/src/annotate_from_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..f1047fcddc9c83782ea5216b80df0e97ebd97099 --- /dev/null +++ b/kba_pipeline/src/annotate_from_mapping.py @@ -0,0 +1,751 @@ +###################################################################################################### +# annotate sequences based on mapping results +###################################################################################################### +#%% +import os +import logging + +import numpy as np +import pandas as pd +from difflib import get_close_matches +from Levenshtein import distance +import json + +from joblib import Parallel, delayed +import multiprocessing + + +from utils import (fasta2df, fasta2df_subheader,log_time, reverse_complement) +from precursor_bins import get_bin_with_max_overlap + + +log = logging.getLogger(__name__) + +pd.options.mode.chained_assignment = None + + +###################################################################################################### +# paths to reference and mapping files +###################################################################################################### + +version = '_v4' + +HBDxBase_csv = f'../../references/HBDxBase/HBDxBase_all{version}.csv' +miRBase_mature_path = '../../references/HBDxBase/miRBase/mature.fa' +mat_miRNA_pos_path = '../../references/HBDxBase/miRBase/hsa_mature_position.txt' + +mapped_file = 
'seqsmapped2HBDxBase_combined.txt' +unmapped_file = 'tmp_seqs3mm2HBDxBase_pseudo__unmapped.fa' +TE_file = 'tmp_seqsmapped2genome_intersect_TE.txt' +mapped_genome_file = 'seqsmapped2genome_combined.txt' +toomanyloci_genome_file = 'tmp_seqs0mm2genome__toomanyalign.fa' +unmapped_adapter_file = 'tmp_seqs3mm2adapters__unmapped.fa' +unmapped_genome_file = 'tmp_seqs0mm2genome__unmapped.fa' +unmapped_bacterial_file = 'tmp_seqs0mm2bacterial__unmapped.fa' +unmapped_viral_file = 'tmp_seqs0mm2viral__unmapped.fa' + + +sRNA_anno_file = 'sRNA_anno_from_mapping.csv' +aggreg_sRNA_anno_file = 'sRNA_anno_aggregated_on_seq.csv' + + + +#%% +###################################################################################################### +# specific functions +###################################################################################################### + +@log_time(log) +def extract_general_info(mapping_file): + # load mapping file + mapping_df = pd.read_csv(mapping_file, sep='\t', header=None) + mapping_df.columns = ['tmp_seq_id','reference','ref_start','sequence','other_alignments','mm_descriptors'] + + # add precursor length + number of bins that will be used for names + HBDxBase_df = pd.read_csv(HBDxBase_csv, index_col=0) + HBDxBase_df = HBDxBase_df[['precursor_length','precursor_bins','pseudo_class']].reset_index() + HBDxBase_df.rename(columns={'index': "reference"}, inplace=True) + mapping_df = mapping_df.merge(HBDxBase_df, left_on='reference', right_on='reference', how='left') + + # extract information + mapping_df.loc[:,'mms'] = mapping_df.mm_descriptors.fillna('').str.count('>') + mapping_df.loc[:,'mm_descriptors'] = mapping_df.mm_descriptors.str.replace(',', ';') + mapping_df.loc[:,'small_RNA_class_annotation'] = mapping_df.reference.str.split('|').str[0] + mapping_df.loc[:,'subclass_type'] = mapping_df.reference.str.split('|').str[2] + mapping_df.loc[:,'precursor_name_full'] = mapping_df.reference.str.split('|').str[1].str.split('|').str[0] + mapping_df.loc[:,'precursor_name'] = mapping_df.precursor_name_full.str.split('__').str[0].str.split('|').str[0] + mapping_df.loc[:,'seq_length'] = mapping_df.sequence.apply(lambda x: len(x)) + mapping_df.loc[:,'ref_end'] = mapping_df.ref_start + mapping_df.seq_length - 1 + mapping_df.loc[:,'mitochondrial'] = np.where(mapping_df.reference.str.contains(r'(\|MT-)|(12S)|(16S)'), 'mito', 'nuclear') + + return mapping_df + + +#%% +@log_time(log) +def tRNA_annotation(mapping_df): + """Extract tRNA specific annotation from mapping. 
+ """ + # keep only tRNA leader/trailer with right cutting sites (+/- 5nt) + # leader + tRF_leader_df = mapping_df[mapping_df['subclass_type'] == 'leader_tRF'] + # assign as misc-leader-tRF if exceeding defined cutting site range + tRF_leader_df.loc[:,'subclass_type'] = np.where((tRF_leader_df.ref_start + tRF_leader_df.sequence.apply(lambda x: len(x))).between(45, 55, inclusive='both'), 'leader_tRF', 'misc-leader-tRF') + + # trailer + tRF_trailer_df = mapping_df[mapping_df['subclass_type'] == 'trailer_tRF'] + # assign as misc-trailer-tRF if exceeding defined cutting site range + tRF_trailer_df.loc[:,'subclass_type'] = np.where(tRF_trailer_df.ref_start.between(0, 5, inclusive='both'), 'trailer_tRF', 'misc-trailer-tRF') + + # define tRF subclasses (leader_tRF and trailer_tRF have been assigned previously) + # NOTE: allow more flexibility at ends (similar to miRNA annotation) + tRNAs_df = mapping_df[((mapping_df['small_RNA_class_annotation'] == 'tRNA') & mapping_df['subclass_type'].isna())] + tRNAs_df.loc[((tRNAs_df.ref_start < 3) & (tRNAs_df.seq_length >= 30)),'subclass_type'] = '5p-tR-half' + tRNAs_df.loc[((tRNAs_df.ref_start < 3) & (tRNAs_df.seq_length < 30)),'subclass_type'] = '5p-tRF' + tRNAs_df.loc[(((tRNAs_df.precursor_length - (tRNAs_df.ref_end + 1)) < 6) & (tRNAs_df.seq_length >= 30)),'subclass_type'] = '3p-tR-half' + tRNAs_df.loc[(((tRNAs_df.precursor_length - (tRNAs_df.ref_end + 1)).between(3,6,inclusive='neither')) & (tRNAs_df.seq_length < 30)),'subclass_type'] = '3p-tRF' + tRNAs_df.loc[(((tRNAs_df.precursor_length - (tRNAs_df.ref_end + 1)) < 3) & (tRNAs_df.seq_length < 30)),'subclass_type'] = '3p-CCA-tRF' + tRNAs_df.loc[tRNAs_df.subclass_type.isna(),'subclass_type'] = 'misc-tRF' + # add ref_iso flag + tRNAs_df['tRNA_ref_iso'] = np.where( + ( + (tRNAs_df.ref_start == 0) + | ((tRNAs_df.ref_end + 1) == tRNAs_df.precursor_length) + | ((tRNAs_df.ref_end + 1) == (tRNAs_df.precursor_length - 3)) + ), 'reftRF', 'isotRF' + ) + # concat tRNA, leader & trailer dfs + tRNAs_df = pd.concat([tRNAs_df, tRF_leader_df, tRF_trailer_df],axis=0) + # adjust precursor name and create tRNA name + tRNAs_df['precursor_name'] = tRNAs_df.precursor_name.str.extract(r"((tRNA-...-...)|(MT-..)|(tRX-...-...)|(tRNA-i...-...))", expand=True)[0] + tRNAs_df['subclass_name'] = tRNAs_df.subclass_type + '__' + tRNAs_df.precursor_name + + return tRNAs_df + +#%% +def faustrules_check(row): + """Check if isomiRs follow Faustrules (based on Tomasello et al. 2021). 
+ """ + + # mark seqs that are not in range +/- 2nt of mature start + # check if ref_start.between(miRNAs_df.mature_start-2, miRNAs_df.mature_start+2, inclusive='both')] + ref_start = row['ref_start'] + mature_start = row['mature_start'] + + if ref_start < mature_start - 2 or ref_start > mature_start + 2: + return False + + # mark seqs with mismatch unless A>G or C>T in seed region (= position 0-8) or 3' polyA/polyT (max 3nt) + if pd.isna(row['mm_descriptors']): + return True + + seed_region_positions = set(range(9)) + non_templated_ends = {'A', 'AA', 'AAA', 'T', 'TT', 'TTT'} + + sequence = row['sequence'] + mm_descriptors = row['mm_descriptors'].split(';') + + seed_region_mismatches = 0 + three_prime_end_mismatches = 0 + + for descriptor in mm_descriptors: + pos, change = descriptor.split(':') + pos = int(pos) + original, new = change.split('>') + + if pos in seed_region_positions and (original == 'A' and new == 'G' or original == 'C' and new == 'T'): + seed_region_mismatches += 1 + + if pos >= len(sequence) - 3 and sequence[pos:] in non_templated_ends: + three_prime_end_mismatches += 1 + + total_mismatches = seed_region_mismatches + three_prime_end_mismatches + + return total_mismatches == len(mm_descriptors) + +@log_time(log) +def miRNA_annotation(mapping_df): + """Extract miRNA specific annotation from mapping. RaH Faustrules are applied. + """ + + miRNAs_df = mapping_df[mapping_df.small_RNA_class_annotation == 'miRNA'] + + nr_missing_alignments_expected = len(miRNAs_df.loc[miRNAs_df.duplicated(['tmp_seq_id','reference'], keep='first'),:]) + + # load positions of mature miRNAs within precursor + miRNA_pos_df = pd.read_csv(mat_miRNA_pos_path, sep='\t') + miRNA_pos_df.drop(columns=['precursor_length'], inplace=True) + miRNAs_df = miRNAs_df.merge(miRNA_pos_df, left_on='precursor_name_full', right_on='name_precursor', how='left') + + # load mature miRNA sequences from miRBase + miRBase_mature_df = fasta2df_subheader(miRBase_mature_path,0) + # subset to human miRNAs + miRBase_mature_df = miRBase_mature_df.loc[miRBase_mature_df.index.str.contains('hsa-'),:] + miRBase_mature_df.index = miRBase_mature_df.index.str.replace('hsa-','') + miRBase_mature_df.reset_index(inplace=True) + miRBase_mature_df.columns = ['name_mature','ref_miR_seq'] + # add 'ref_miR_seq' + miRNAs_df = miRNAs_df.merge(miRBase_mature_df, left_on='name_mature', right_on='name_mature', how='left') + + # for each duplicated tmp_seq_id/reference combi, keep the one lowest lev dist of sequence to ref_miR_seq + miRNAs_df['lev_dist'] = miRNAs_df.apply(lambda x: distance(x['sequence'], x['ref_miR_seq']), axis=1) + miRNAs_df = miRNAs_df.sort_values(by=['tmp_seq_id','lev_dist'], ascending=[True, True]).drop_duplicates(['tmp_seq_id','reference'], keep='first') + + # add ref_iso flag + miRNAs_df['miRNA_ref_iso'] = np.where( + ( + (miRNAs_df.ref_start == miRNAs_df.mature_start) + & (miRNAs_df.ref_end == miRNAs_df.mature_end) + & (miRNAs_df.mms == 0) + ), 'refmiR', 'isomiR' + ) + + # apply RaH Faustrules + miRNAs_df['faustrules_check'] = miRNAs_df.apply(faustrules_check, axis=1) + + # set miRNA_ref_iso to 'misc-miR' if faustrules_check is False + miRNAs_df.loc[~miRNAs_df.faustrules_check,'miRNA_ref_iso'] = 'misc-miR' + + # set subclass_name to name_mature if faustrules_check is True, else use precursor_name + miRNAs_df['subclass_name'] = np.where(miRNAs_df.faustrules_check, miRNAs_df.name_mature, miRNAs_df.precursor_name) + + # store name_mature for functional analysis as miRNA_names, set miR- to mir- if faustrules_check is False + 
miRNAs_df['miRNA_names'] = np.where(miRNAs_df.faustrules_check, miRNAs_df.name_mature, miRNAs_df.name_mature.str.replace('miR-', 'mir-')) + + # add subclass (NOTE: in cases where subclass is not part of mature name, use position relative to precursor half to define group ) + miRNAs_df['subclass_type'] = np.where(miRNAs_df.name_mature.str.endswith('5p'), '5p', np.where(miRNAs_df.name_mature.str.endswith('3p'), '3p', 'tbd')) + miRNAs_df.loc[((miRNAs_df.subclass_type == 'tbd') & (miRNAs_df.mature_start < miRNAs_df.precursor_length/2)), 'subclass_type'] = '5p' + miRNAs_df.loc[((miRNAs_df.subclass_type == 'tbd') & (miRNAs_df.mature_start >= miRNAs_df.precursor_length/2)), 'subclass_type'] = '3p' + + # subset to relevant columns + miRNAs_df = miRNAs_df[list(mapping_df.columns) + ['subclass_name','miRNA_ref_iso','miRNA_names','ref_miR_seq']] + + return miRNAs_df, nr_missing_alignments_expected + + +#%% +###################################################################################################### +# annotation of other sRNA classes +###################################################################################################### +def get_bin_with_max_overlap_parallel(df): + return df.apply(get_bin_with_max_overlap, axis=1) + +def applyParallel(df, func): + retLst = Parallel(n_jobs=multiprocessing.cpu_count())(delayed(func)(group) for group in np.array_split(df,30)) + return pd.concat(retLst) + + +@log_time(log) +def other_sRNA_annotation_new_binning(mapping_df): + """Generate subclass_name for non-tRNA/miRNA sRNAs by precursor-binning. + New binning approach: bin size is dynamically determined by the precursor length. Assignments are based on the bin with the highest overlap. + """ + + other_sRNAs_df = mapping_df[~((mapping_df.small_RNA_class_annotation == 'miRNA') | (mapping_df.small_RNA_class_annotation == 'tRNA'))] + + #create empty columns; bin start and bin end + other_sRNAs_df['bin_start'] = '' + other_sRNAs_df['bin_end'] = '' + + other_sRNAs_df = applyParallel(other_sRNAs_df, get_bin_with_max_overlap_parallel) + + return other_sRNAs_df + + +#%% +@log_time(log) +def extract_sRNA_class_specific_info(mapping_df): + tRNAs_df = tRNA_annotation(mapping_df) + miRNAs_df, nr_missing_alignments_expected = miRNA_annotation(mapping_df) + other_sRNAs_df = other_sRNA_annotation_new_binning(mapping_df) + + # add miRNA columns + tRNAs_df[['miRNA_ref_iso', 'miRNA_names', 'ref_miR_seq']] = pd.DataFrame(columns=['miRNA_ref_iso', 'miRNA_names', 'ref_miR_seq']) + other_sRNAs_df[['miRNA_ref_iso', 'miRNA_names', 'ref_miR_seq']] = pd.DataFrame(columns=['miRNA_ref_iso', 'miRNA_names', 'ref_miR_seq']) + + # re-concat sRNA class dfs + sRNA_anno_df = pd.concat([miRNAs_df, tRNAs_df, other_sRNAs_df],axis=0) + + # TEST if alignments were lost or duplicated + assert ((len(mapping_df) - nr_missing_alignments_expected) == len(sRNA_anno_df)), "alignments were lost or duplicated" + + return sRNA_anno_df + +#%% +def get_nth_nt(row): + return row['sequence'][int(row['PTM_position_in_seq'])-1] + + + +#%% +@log_time(log) +def aggregate_info_per_seq(sRNA_anno_df): + # fillna of 'subclass_name_bin_pos' with 'subclass_name' + sRNA_anno_df['subclass_name_bin_pos'] = sRNA_anno_df['subclass_name_bin_pos'].fillna(sRNA_anno_df['subclass_name']) + # get aggregated info per seq + aggreg_per_seq_df = sRNA_anno_df.groupby(['sequence']).agg({'small_RNA_class_annotation': lambda x: ';'.join(sorted(x.unique())), 'pseudo_class': lambda x: ';'.join(x.astype(str).sort_values(ascending=True).unique()), 'subclass_type': lambda x: 
';'.join(x.astype(str).sort_values(ascending=True).unique()), 'subclass_name': lambda x: ';'.join(sorted(x.unique())), 'subclass_name_bin_pos': lambda x: ';'.join(sorted(x.unique())), 'miRNA_names': lambda x: ';'.join(x.fillna('').unique()), 'precursor_name_full': lambda x: ';'.join(sorted(x.unique())), 'mms': lambda x: ';'.join(x.astype(str).sort_values(ascending=True).unique()), 'reference': lambda x: len(x), 'mitochondrial': lambda x: ';'.join(x.astype(str).sort_values(ascending=True).unique()), 'ref_miR_seq': lambda x: ';'.join(x.fillna('').unique())}) + aggreg_per_seq_df['miRNA_names'] = aggreg_per_seq_df.miRNA_names.str.replace(r';$','', regex=True) + aggreg_per_seq_df['ref_miR_seq'] = aggreg_per_seq_df.ref_miR_seq.str.replace(r';$','', regex=True) + aggreg_per_seq_df['mms'] = aggreg_per_seq_df['mms'].astype(int) + + # re-add 'miRNA_ref_iso','tRNA_ref_iso' + refmir_df = sRNA_anno_df[['sequence','miRNA_ref_iso','tRNA_ref_iso']] + refmir_df.drop_duplicates('sequence', inplace=True) + refmir_df.set_index('sequence', inplace=True) + aggreg_per_seq_df = aggreg_per_seq_df.merge(refmir_df, left_index=True, right_index=True, how='left') + + # TEST if sequences were lost + assert (len(aggreg_per_seq_df) == len(sRNA_anno_df.sequence.unique())), "sequences were lost by aggregation" + + # load unmapped seqs, if it exits + if os.path.exists(unmapped_file): + unmapped_df = fasta2df(unmapped_file) + unmapped_df = pd.DataFrame(data='no_annotation', index=unmapped_df.sequence, columns=aggreg_per_seq_df.columns) + unmapped_df['mms'] = np.nan + unmapped_df['reference'] = np.nan + unmapped_df['pseudo_class'] = True # set no annotation as pseudo_class + + # merge mapped and unmapped + annotation_df = pd.concat([aggreg_per_seq_df,unmapped_df]) + else: + annotation_df = aggreg_per_seq_df.copy() + + # load mapping to genome file + mapping_genome_df = pd.read_csv(mapped_genome_file, index_col=0, sep='\t', header=None) + mapping_genome_df.columns = ['strand','reference','ref_start','sequence','other_alignments','mm_descriptors'] + mapping_genome_df = mapping_genome_df[['strand','reference','ref_start','sequence','other_alignments']] + + # use reverse complement of 'sequence' for 'strand' == '-' + mapping_genome_df.loc[:,'sequence'] = np.where(mapping_genome_df.strand == '-', mapping_genome_df.sequence.apply(lambda x: reverse_complement(x)), mapping_genome_df.sequence) + + # get aggregated info per seq + aggreg_per_seq__genome_df = mapping_genome_df.groupby('sequence').agg({'reference': lambda x: ';'.join(sorted(x.unique())), 'other_alignments': lambda x: len(x)}) + aggreg_per_seq__genome_df['other_alignments'] = aggreg_per_seq__genome_df['other_alignments'].astype(int) + + # number of genomic loci + genomic_loci_df = pd.DataFrame(mapping_genome_df.sequence.value_counts()) + genomic_loci_df.columns = ['num_genomic_loci_maps'] + + # load too many aligments seqs + if os.path.exists(toomanyloci_genome_file): + toomanyloci_genome_df = fasta2df(toomanyloci_genome_file) + toomanyloci_genome_df = pd.DataFrame(data=101, index=toomanyloci_genome_df.sequence, columns=genomic_loci_df.columns) + else: + toomanyloci_genome_df = pd.DataFrame(columns=genomic_loci_df.columns) + + # load unmapped seqs + if os.path.exists(unmapped_genome_file): + unmapped_genome_df = fasta2df(unmapped_genome_file) + unmapped_genome_df = pd.DataFrame(data=0, index=unmapped_genome_df.sequence, columns=genomic_loci_df.columns) + else: + unmapped_genome_df = pd.DataFrame(columns=genomic_loci_df.columns) + + # concat toomanyloci, unmapped, and 
genomic_loci + num_genomic_loci_maps_df = pd.concat([genomic_loci_df,toomanyloci_genome_df,unmapped_genome_df]) + + # merge to annotation_df + annotation_df = annotation_df.merge(num_genomic_loci_maps_df, left_index=True, right_index=True, how='left') + annotation_df.reset_index(inplace=True) + + # add 'miRNA_seed' + annotation_df.loc[:,"miRNA_seed"] = np.where(annotation_df.small_RNA_class_annotation.str.contains('miRNA', na=False), annotation_df.sequence.str[1:9], "") + + # TEST if nan values in 'num_genomic_loci_maps' + assert (annotation_df.num_genomic_loci_maps.isna().any() == False), "nan values in 'num_genomic_loci_maps'" + + return annotation_df + + + + +#%% +@log_time(log) +def get_five_prime_adapter_info(annotation_df, five_prime_adapter): + adapter_df = pd.DataFrame(index=annotation_df.sequence) + + min_length = 6 + + is_prefixed = None + print("5' adapter affixes:") + for l in range(0, len(five_prime_adapter) - min_length): + is_prefixed_l = adapter_df.index.str.startswith(five_prime_adapter[l:]) + print(f"{five_prime_adapter[l:].ljust(30, ' ')}{is_prefixed_l.sum()}") + adapter_df.loc[adapter_df.index.str.startswith(five_prime_adapter[l:]), "five_prime_adapter_length"] = len(five_prime_adapter[l:]) + if is_prefixed is None: + is_prefixed = is_prefixed_l + else: + is_prefixed |= is_prefixed_l + + print(f"There are {is_prefixed.sum()} prefixed features.") + print("\n") + + adapter_df['five_prime_adapter_length'] = adapter_df['five_prime_adapter_length'].fillna(0) + adapter_df['five_prime_adapter_length'] = adapter_df['five_prime_adapter_length'].astype('int') + adapter_df['five_prime_adapter_filter'] = np.where(adapter_df['five_prime_adapter_length'] == 0, True, False) + adapter_df = adapter_df.reset_index() + + return adapter_df + +#%% +@log_time(log) +def reduce_ambiguity(annotation_df: pd.DataFrame) -> pd.DataFrame: + """Reduce ambiguity by + + a) using subclass_name of precursor with shortest genomic context, if all other assigned precursors overlap with its genomic region + + b) using subclass_name whose bin is at the 5' or 3' end of the precursor + + Parameters + ---------- + annotation_df : pd.DataFrame + A DataFrame containing the annotation of the sequences (var) + + Returns + ------- + pd.DataFrame + An improved version of the input DataFrame with reduced ambiguity + """ + + # extract ambigious assignments for subclass name + ambigious_matches_df = annotation_df[annotation_df.subclass_name.str.contains(';',na=False)] + if len(ambigious_matches_df) == 0: + print('No ambigious assignments for subclass name found.') + return annotation_df + clear_matches_df = annotation_df[~annotation_df.subclass_name.str.contains(';',na=False)] + + # extract required information from HBDxBase + HBDxBase_all_df = pd.read_csv(HBDxBase_csv, index_col=0) + bin_dict = HBDxBase_all_df[['precursor_name','precursor_bins']].set_index('precursor_name').to_dict()['precursor_bins'] + sRNA_class_dict = HBDxBase_all_df[['precursor_name','small_RNA_class_annotation']].set_index('precursor_name').to_dict()['small_RNA_class_annotation'] + pseudo_class_dict = HBDxBase_all_df[['precursor_name','pseudo_class']].set_index('precursor_name').to_dict()['pseudo_class'] + sc_type_dict = HBDxBase_all_df[['precursor_name','subclass_type']].set_index('precursor_name').to_dict()['subclass_type'] + genomic_context_bed = HBDxBase_all_df[['chr','start','end','precursor_name','score','strand']] + genomic_context_bed.columns = ['seq_id','start','end','name','score','strand'] + genomic_context_bed.reset_index(drop=True, 
inplace=True) + genomic_context_bed['genomic_length'] = genomic_context_bed.end - genomic_context_bed.start + + + def get_overlaps(genomic_context_bed: pd.DataFrame, name: str = None, complement: bool = False) -> list: + """Get genomic overlap of a given precursor name + + Parameters + ---------- + genomic_context_bed : pd.DataFrame + A DataFrame containing genomic locations of precursors in bed format + with column names: 'chr','start','end','precursor_name','score','strand' + name : str + The name of the precursor to get genomic context for + complement : bool + If True, return all precursors that do not overlap with the given precursor + + Returns + ------- + list + A list containing the precursors in the genomic (anti-)context of the given precursor + (including the precursor itself) + """ + series_OI = genomic_context_bed[genomic_context_bed['name'] == name] + start = series_OI['start'].values[0] + end = series_OI['end'].values[0] + seq_id = series_OI['seq_id'].values[0] + strand = series_OI['strand'].values[0] + + overlap_df = genomic_context_bed.copy() + + condition = (((overlap_df.start > start) & + (overlap_df.start < end)) | + ((overlap_df.end > start) & + (overlap_df.end < end)) | + ((overlap_df.start < start) & + (overlap_df.end > start)) | + ((overlap_df.start == start) & + (overlap_df.end == end)) | + ((overlap_df.start == start) & + (overlap_df.end > end)) | + ((overlap_df.start < start) & + (overlap_df.end == end))) + if not complement: + overlap_df = overlap_df[condition] + else: + overlap_df = overlap_df[~condition] + overlap_df = overlap_df[overlap_df.seq_id == seq_id] + if strand is not None: + overlap_df = overlap_df[overlap_df.strand == strand] + overlap_list = overlap_df['name'].tolist() + return overlap_list + + + def check_genomic_ctx_of_smallest_prec(precursor_name: str) -> str: + """Check for a given ambigious precursor assignment (several names separated by ';') + if all assigned precursors overlap with the genomic region + of the precursor with the shortest genomic context + + Parameters + ---------- + precursor_name: str + A string containing several precursor names separated by ';' + + Returns + ------- + str + The precursor suggested to be used instead of the multi assignment, + or None if the ambiguity could not be resolved + """ + assigned_names = precursor_name.split(';') + + tmp_genomic_context = genomic_context_bed[genomic_context_bed.name.isin(assigned_names)] + # get name of smallest genomic region + if len(tmp_genomic_context) > 0: + smallest_name = tmp_genomic_context.name[tmp_genomic_context.genomic_length.idxmin()] + # check if all assigned names are in overlap of smallest genomic region + if set(assigned_names).issubset(set(get_overlaps(genomic_context_bed,smallest_name))): + return smallest_name + else: + return None + else: + return None + + def get_subclass_name(subclass_name: str, short_prec_match_new_name: str) -> str: + """Get subclass name matching to a precursor name from a ambigious assignment (several names separated by ';') + + Parameters + ---------- + subclass_name: str + A string containing several subclass names separated by ';' + short_prec_match_new_name: str + The name of the precursor to be used instead of the multi assignment + + Returns + ------- + str + The subclass name suggested to be used instead of the multi assignment, + or None if the ambiguity could not be resolved + """ + if short_prec_match_new_name is not None: + matches = get_close_matches(short_prec_match_new_name,subclass_name.split(';'),cutoff=0.2) + if 
matches: + return matches[0] + else: + print(f"Could not find match for {short_prec_match_new_name} in {subclass_name}") + return subclass_name + else: + return None + + + def check_end_bins(subclass_name: str) -> str: + """Check for a given ambigious subclass name assignment (several names separated by ';') + if ambiguity can be resolved by selecting the subclass name whose bin matches the 3'/5' end of the precursor + + Parameters + ---------- + subclass_name: str + A string containing several subclass names separated by ';' + + Returns + ------- + str + The subclass name suggested to be used instead of the multi assignment, + or None if the ambiguity could not be resolved + """ + for name in subclass_name.split(';'): + if '_bin-' in name: + name_parts = name.split('_bin-') + if name_parts[0] in bin_dict and bin_dict[name_parts[0]] == int(name_parts[1]): + return name + elif int(name_parts[1]) == 1: + return name + return None + + + def adjust_4_resolved_cases(row: pd.Series) -> tuple: + """For a resolved ambiguous subclass names return adjusted values of + precursor_name_full, small_RNA_class_annotation, pseudo_class, and subclass_type + + Parameters + ---------- + row: pd.Series + A row of the var annotation containing the columns 'subclass_name', 'precursor_name_full', + 'small_RNA_class_annotation', 'pseudo_class', 'subclass_type', and 'ambiguity_resolved' + + Returns + ------- + tuple + A tuple containing the adjusted values of 'precursor_name_full', 'small_RNA_class_annotation', + 'pseudo_class', and 'subclass_type' for resolved ambiguous cases and the original values for unresolved cases + """ + if row.ambiguity_resolved: + matches_prec = get_close_matches(row.subclass_name, row.precursor_name_full.split(';'), cutoff=0.2) + if matches_prec: + return matches_prec[0], sRNA_class_dict[matches_prec[0]], pseudo_class_dict[matches_prec[0]], sc_type_dict[matches_prec[0]] + return row.precursor_name_full, row.small_RNA_class_annotation, row.pseudo_class, row.subclass_type + + + # resolve ambiguity by checking genomic context of smallest precursor + ambigious_matches_df['short_prec_match_new_name'] = ambigious_matches_df.precursor_name_full.apply(check_genomic_ctx_of_smallest_prec) + ambigious_matches_df['short_prec_match_new_name'] = ambigious_matches_df.apply(lambda x: get_subclass_name(x.subclass_name, x.short_prec_match_new_name), axis=1) + ambigious_matches_df['short_prec_match'] = ambigious_matches_df['short_prec_match_new_name'].notnull() + + # resolve ambiguity by checking if bin matches 3'/5' end of precursor + ambigious_matches_df['end_bin_match_new_name'] = ambigious_matches_df.subclass_name.apply(check_end_bins) + ambigious_matches_df['end_bin_match'] = ambigious_matches_df['end_bin_match_new_name'].notnull() + + # check if short_prec_match and end_bin_match are equal in any case + test_df = ambigious_matches_df[((ambigious_matches_df.short_prec_match == True) & (ambigious_matches_df.end_bin_match == True))] + if not (test_df.short_prec_match_new_name == test_df.end_bin_match_new_name).all(): + print('Number of cases where short_prec_match is not matching end_bin_match_new_name:',len(test_df[(test_df.short_prec_match_new_name != test_df.end_bin_match_new_name)])) + + # replace subclass_name with short_prec_match_new_name or end_bin_match_new_name + # NOTE: if short_prec_match and end_bin_match are True, short_prec_match_new_name is used + ambigious_matches_df['subclass_name'] = ambigious_matches_df.apply(lambda x: x.end_bin_match_new_name if x.end_bin_match == True else 
x.subclass_name, axis=1) + ambigious_matches_df['subclass_name'] = ambigious_matches_df.apply(lambda x: x.short_prec_match_new_name if x.short_prec_match == True else x.subclass_name, axis=1) + + # generate column 'ambiguity_resolved' which is True if short_prec_match and/or end_bin_match is True + ambigious_matches_df['ambiguity_resolved'] = ambigious_matches_df.short_prec_match | ambigious_matches_df.end_bin_match + print("Ambiguity resolved?\n",ambigious_matches_df.ambiguity_resolved.value_counts(normalize=True)) + + # for resolved ambiguous matches, adjust precursor_name_full, small_RNA_class_annotation, pseudo_class, subclass_type + ambigious_matches_df[['precursor_name_full','small_RNA_class_annotation','pseudo_class','subclass_type']] = ambigious_matches_df.apply(adjust_4_resolved_cases, axis=1, result_type='expand') + + # drop temporary columns + ambigious_matches_df.drop(columns=['short_prec_match_new_name','short_prec_match','end_bin_match_new_name','end_bin_match'], inplace=True) + + # concat with clear_matches_df + clear_matches_df['ambiguity_resolved'] = False + improved_annotation_df = pd.concat([clear_matches_df, ambigious_matches_df], axis=0) + improved_annotation_df = improved_annotation_df.reindex(annotation_df.index) + + return improved_annotation_df + +#%% +###################################################################################################### +# HICO (=high confidence) annotation +###################################################################################################### +@log_time(log) +def add_hico_annotation(annotation_df, five_prime_adapter): + """For miRNAs only use hico annotation if part of miRBase hico set AND refmiR + """ + + # add 'TE_annotation' + TE_df = pd.read_csv(TE_file, sep='\t', header=None, names=['sequence','TE_annotation']) + annotation_df = annotation_df.merge(TE_df, left_on='sequence', right_on='sequence', how='left') + + # add 'bacterial' mapping filter + bacterial_unmapped_df = fasta2df(unmapped_bacterial_file) + annotation_df.loc[:,'bacterial'] = np.where(annotation_df.sequence.isin(bacterial_unmapped_df.sequence), False, True) + + # add 'viral' mapping filter + viral_unmapped_df = fasta2df(unmapped_viral_file) + annotation_df.loc[:,'viral'] = np.where(annotation_df.sequence.isin(viral_unmapped_df.sequence), False, True) + + # add 'adapter_mapping_filter' column + adapter_unmapped_df = fasta2df(unmapped_adapter_file) + annotation_df.loc[:,'adapter_mapping_filter'] = np.where(annotation_df.sequence.isin(adapter_unmapped_df.sequence), True, False) + + # add filter column 'five_prime_adapter_filter' and column 'five_prime_adapter_length' indicating the length of the prefixed 5' adapter sequence + adapter_df = get_five_prime_adapter_info(annotation_df, five_prime_adapter) + annotation_df = annotation_df.merge(adapter_df, left_on='sequence', right_on='sequence', how='left') + + # apply ambiguity reduction + annotation_df = reduce_ambiguity(annotation_df) + + # add 'single_class_annotation' + annotation_df.loc[:,'single_class_annotation'] = np.where(annotation_df.small_RNA_class_annotation.str.contains(';',na=True), False, True) + + # add 'single_name_annotation' + annotation_df.loc[:,'single_name_annotation'] = np.where(annotation_df.subclass_name.str.contains(';',na=True), False, True) + + # add 'hypermapper' for sequences where more than 50 potential mapping references are recorded + annotation_df.loc[annotation_df.reference > 50,'subclass_name'] = 'hypermapper_' + 
annotation_df.reference.fillna(0).astype(int).astype(str) + annotation_df.loc[annotation_df.reference > 50,'subclass_name_bin_pos'] = 'hypermapper_' + annotation_df.reference.fillna(0).astype(int).astype(str) + annotation_df.loc[annotation_df.reference > 50,'precursor_name_full'] = 'hypermapper_' + annotation_df.reference.fillna(0).astype(int).astype(str) + + annotation_df.loc[:,'mitochondrial'] = np.where(annotation_df.mitochondrial.str.contains('mito',na=False), True, False) + + # add 'hico' + annotation_df.loc[:,'hico'] = np.where(( + (annotation_df.mms == 0) + & (annotation_df.single_name_annotation == True) + & (annotation_df.TE_annotation.isna() == True) + & (annotation_df.bacterial == False) + & (annotation_df.viral == False) + & (annotation_df.adapter_mapping_filter == True) + & (annotation_df.five_prime_adapter_filter == True) + ), True, False) + ## NOTE: for miRNAs only use hico annotation if part of refmiR set + annotation_df.loc[annotation_df.small_RNA_class_annotation == 'miRNA','hico'] = annotation_df.loc[annotation_df.small_RNA_class_annotation == 'miRNA','hico'] & (annotation_df.miRNA_ref_iso == 'refmiR') + + print(annotation_df[annotation_df.single_class_annotation == True].groupby('small_RNA_class_annotation').hico.value_counts()) + + return annotation_df + + + + +#%% +###################################################################################################### +# annotation pipeline +###################################################################################################### +@log_time(log) +def main(five_prime_adapter): + """Executes 'annotate_from_mapping'. + + Uses: + + - HBDxBase_csv + - miRBase_mature_path + - mat_miRNA_pos_path + + - mapping_file + - unmapped_file + - mapped_genome_file + - toomanyloci_genome_file + - unmapped_genome_file + + - TE_file + - unmapped_adapter_file + - unmapped_bacterial_file + - unmapped_viral_file + - five_prime_adapter + + """ + + + print('-------- extract general information for sequences that mapped to the HBDxBase --------') + mapped_info_df = extract_general_info(mapped_file) + print("\n") + + print('-------- extract sRNA class specific information for sequences that mapped to the HBDxBase --------') + mapped_sRNA_anno_df = extract_sRNA_class_specific_info(mapped_info_df) + + print('-------- save to file --------') + mapped_sRNA_anno_df.to_csv(sRNA_anno_file) + print("\n") + + print('-------- aggregate information for mapped and unmapped sequences (HBDxBase & human genome) --------') + sRNA_anno_per_seq_df = aggregate_info_per_seq(mapped_sRNA_anno_df) + print("\n") + + print('-------- add hico annotation (based on aggregated infos + mapping to viral/bacterial genomes + intersection with TEs) --------') + sRNA_anno_per_seq_df = add_hico_annotation(sRNA_anno_per_seq_df, five_prime_adapter) + print("\n") + + print('-------- save to file --------') + # set sequence as index again + sRNA_anno_per_seq_df.set_index('sequence', inplace=True) + sRNA_anno_per_seq_df.to_csv(aggreg_sRNA_anno_file) + print("\n") + + print('-------- generate subclass_to_annotation dict --------') + result_df = sRNA_anno_per_seq_df[['subclass_name', 'small_RNA_class_annotation']].copy() + result_df.reset_index(drop=True, inplace=True) + result_df.drop_duplicates(inplace=True) + result_df = result_df[~result_df["subclass_name"].str.contains(";")] + subclass_to_annotation = dict(zip(result_df["subclass_name"],result_df["small_RNA_class_annotation"])) + with open('subclass_to_annotation.json', 'w') as fp: + json.dump(subclass_to_annotation, fp) 
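+    # The resulting JSON maps each unambiguous subclass_name to its small RNA class, e.g.
+    # (hypothetical entries): {"miR-21-5p": "miRNA", "5p-tR-half__tRNA-Gly-GCC": "tRNA", "SNORD3A_bin-1": "snoRNA"}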
+ + print('-------- delete tmp files --------') + os.system("rm *tmp_*") + + +#%% diff --git a/kba_pipeline/src/make_anno.py b/kba_pipeline/src/make_anno.py new file mode 100644 index 0000000000000000000000000000000000000000..c8344498d15c1a658db3fefcc48a7c482fea5ccd --- /dev/null +++ b/kba_pipeline/src/make_anno.py @@ -0,0 +1,59 @@ + +#%% +import argparse +import os +import logging + +from utils import make_output_dir,write_2_log,log_time +import map_2_HBDxBase as map_2_HBDxBase +import annotate_from_mapping as annotate_from_mapping + + +log = logging.getLogger(__name__) + + + +#%% +# get command line arguments +parser = argparse.ArgumentParser() +parser.add_argument('--five_prime_adapter', type=str, default='GTTCAGAGTTCTACAGTCCGACGATC') +parser.add_argument('--fasta_file', type=str, help="Required to provide: --fasta_file sequences_to_be_annotated.fa") # NOTE: needs to be stored in "data" folder +args = parser.parse_args() +if not args.fasta_file: + parser.print_help() + exit() +five_prime_adapter = args.five_prime_adapter +sequence_file = args.fasta_file + +#%% +@log_time(log) +def main(five_prime_adapter, sequence_file): + """Executes 'make_anno'. + 1. Maps input sequences to HBDxBase, the human genome, and a collection of viral and bacterial genomes. + 2. Extracts information from mapping files. + 3. Generates annotation columns and final annotation dataframe. + + Uses: + + - sequence_file + - five_prime_adapter + + """ + output_dir = make_output_dir(sequence_file) + os.chdir(output_dir) + + log_folder = "log" + if not os.path.exists(log_folder): + os.makedirs(log_folder) + write_2_log(f"{log_folder}/make_anno.log") + + # add name of sequence_file to log file + with open(f"{log_folder}/make_anno.log", "a") as ofile: + ofile.write(f"Sequence file: {sequence_file}\n") + + map_2_HBDxBase.main("../../data/" + sequence_file) + annotate_from_mapping.main(five_prime_adapter) + + +main(five_prime_adapter, sequence_file) +# %% diff --git a/kba_pipeline/src/map_2_HBDxBase.py b/kba_pipeline/src/map_2_HBDxBase.py new file mode 100644 index 0000000000000000000000000000000000000000..41e705660153f31c28f7f96a9c1a01b2815881a9 --- /dev/null +++ b/kba_pipeline/src/map_2_HBDxBase.py @@ -0,0 +1,318 @@ +###################################################################################################### +# map sequences to HBDxBase +###################################################################################################### +#%% +import os +import logging + +from utils import fasta2df,log_time + +log = logging.getLogger(__name__) + + +###################################################################################################### +# paths to reference files +###################################################################################################### + +version = '_v4' + +HBDxBase_index_path = f'../../references/HBDxBase/HBDxBase{version}' +HBDxBase_pseudo_index_path = f'../../references/HBDxBase/HBDxBase_pseudo{version}' +genome_index_path = '../../references/hg38/genome' +adapter_index_path = '../../references/HBDxBase/adapters' +TE_path = '../../references/hg38/TE.bed' +bacterial_index_path = '../../references/bacterial_viral/all_bacterial_refseq_with_human_host__201127.index' +viral_index_path = '../../references/bacterial_viral/viral_refseq_with_human_host__201127.index' + + + + +#%% +###################################################################################################### +# specific functions 
+###################################################################################################### + +@log_time(log) +def prepare_input_files(seq_input): + + # check if seq_input is path or list + if type(seq_input) == str: + # get seqs in dataset + seqs = fasta2df(seq_input) + seqs = seqs.sequence + elif type(seq_input) == list: + seqs = seq_input + else: + raise ValueError('seq_input must be either path to fasta file or list of sequences') + + # add number of sequences to log file + log_folder = "log" + with open(f"{log_folder}/make_anno.log", "a") as ofile: + ofile.write(f"KBA pipeline based on HBDxBase{version}\n") + ofile.write(f"Number of sequences to be annotated: {str(len(seqs))}\n") + + if type(seq_input) == str: + with open('seqs.fa', 'w') as ofile_1: + for i in range(len(seqs)): + ofile_1.write(">" + seqs.index[i] + "\n" + seqs[i] + "\n") + else: + with open('seqs.fa', 'w') as ofile_1: + for i in range(len(seqs)): + ofile_1.write(">seq_" + str(i) + "\n" + seqs[i] + "\n") + +@log_time(log) +def map_seq_2_HBDxBase( + number_mm, + fasta_in_file, + out_prefix +): + + bowtie_index_file = HBDxBase_index_path + + os.system( + f"bowtie -a --norc -v {number_mm} -f --suppress 2,6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \ + --al {out_prefix + str(number_mm) + 'mm2HBDxBase__mapped.fa'} \ + --un {out_prefix + str(number_mm) + 'mm2HBDxBase__unmapped.fa'} \ + {out_prefix + str(number_mm) + 'mm2HBDxBase.txt'}" + ) +@log_time(log) +def map_seq_2_HBDxBase_pseudo( + number_mm, + fasta_in_file, + out_prefix +): + + bowtie_index_file = HBDxBase_pseudo_index_path + + os.system( + f"bowtie -a --norc -v {number_mm} -f --suppress 2,6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \ + --al {out_prefix + str(number_mm) + 'mm2HBDxBase_pseudo__mapped.fa'} \ + --un {out_prefix + str(number_mm) + 'mm2HBDxBase_pseudo__unmapped.fa'} \ + {out_prefix + str(number_mm) + 'mm2HBDxBase_pseudo.txt'}" + ) + # -a Report all valid alignments per read + # --norc No mapping to reverse strand + # -v Report alignments with at most mismatches + # -f f for FASTA, -q for FASTQ; for our pipeline FASTA makes more sense + # -suppress Suppress columns of output in the default output mode + # -x The basename of the Bowtie, or Bowtie 2, index to be searched + +@log_time(log) +def map_seq_2_adapters( + fasta_in_file, + out_prefix +): + + bowtie_index_file = adapter_index_path + + os.system( + f"bowtie -a --best --strata --norc -v 3 -f --suppress 2,6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \ + --al {out_prefix + '3mm2adapters__mapped.fa'} \ + --un {out_prefix + '3mm2adapters__unmapped.fa'} \ + {out_prefix + '3mm2adapters.txt'}" + ) + # -a --best --strata Specifying --strata in addition to -a and --best causes bowtie to report only those alignments in the best alignment “stratum”. 
The alignments in the best stratum are those having the least number of mismatches + # --norc No mapping to reverse strand + # -v Report alignments with at most mismatches + # -f f for FASTA, -q for FASTQ; for our pipeline FASTA makes more sense + # -suppress Suppress columns of output in the default output mode + # -x The basename of the Bowtie, or Bowtie 2, index to be searched + + +@log_time(log) +def map_seq_2_genome( + fasta_in_file, + out_prefix +): + + bowtie_index_file = genome_index_path + + os.system( + f"bowtie -a -v 0 -f -m 100 --suppress 6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \ + --max {out_prefix + '0mm2genome__toomanyalign.fa'} \ + --un {out_prefix + '0mm2genome__unmapped.fa'} \ + {out_prefix + '0mm2genome.txt'}" + ) + # -a Report all valid alignments per read + # -v Report alignments with at most mismatches + # -f f for FASTA, -q for FASTQ; for our pipeline FASTA makes more sense + # -m Suppress all alignments for a particular read if more than reportable alignments exist for it + # -suppress Suppress columns of output in the default output mode + # -x The basename of the Bowtie, or Bowtie 2, index to be searched + + +@log_time(log) +def map_seq_2_bacterial_viral( + fasta_in_file, + out_prefix +): + + bowtie_index_file = bacterial_index_path + + os.system( + f"bowtie -a -v 0 -f -m 10 --suppress 6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \ + --al {out_prefix + '0mm2bacterial__mapped.fa'} \ + --max {out_prefix + '0mm2bacterial__toomanyalign.fa'} \ + --un {out_prefix + '0mm2bacterial__unmapped.fa'} \ + {out_prefix + '0mm2bacterial.txt'}" + ) + + + bowtie_index_file = viral_index_path + + os.system( + f"bowtie -a -v 0 -f -m 10 --suppress 6 --threads 8 -x {bowtie_index_file} {fasta_in_file} \ + --al {out_prefix + '0mm2viral__mapped.fa'} \ + --max {out_prefix + '0mm2viral__toomanyalign.fa'} \ + --un {out_prefix + '0mm2viral__unmapped.fa'} \ + {out_prefix + '0mm2viral.txt'}" + ) + # -a Report all valid alignments per read + # -v Report alignments with at most mismatches + # -f f for FASTA, -q for FASTQ; for our pipeline FASTA makes more sense + # -m Suppress all alignments for a particular read if more than reportable alignments exist for it + # -suppress Suppress columns of output in the default output mode + # -x The basename of the Bowtie, or Bowtie 2, index to be searched + + + + + +#%% +###################################################################################################### +# mapping pipeline +###################################################################################################### +@log_time(log) +def main(sequence_file): + """Executes 'map_2_HBDxBase'. Maps input sequences to HBDxBase, the human genome, and a collection of viral and bacterial genomes. 
+ + Uses: + + - HBDxBase_index_path + - HBDxBase_pseudo_index_path + - genome_index_path + - bacterial_index_path + - viral_index_path + - sequence_file + + """ + + prepare_input_files(sequence_file) + + # sequential mm mapping to HBDxBase + print('-------- map to HBDxBase --------') + + print('-------- mapping seqs (0 mm) --------') + map_seq_2_HBDxBase( + 0, + 'seqs.fa', + 'tmp_seqs' + ) + + print('-------- mapping seqs (1 mm) --------') + map_seq_2_HBDxBase( + 1, + 'tmp_seqs0mm2HBDxBase__unmapped.fa', + 'tmp_seqs' + ) + + print('-------- mapping seqs (2 mm) --------') + map_seq_2_HBDxBase( + 2, + 'tmp_seqs1mm2HBDxBase__unmapped.fa', + 'tmp_seqs' + ) + + print('-------- mapping seqs (3 mm) --------') + map_seq_2_HBDxBase( + 3, + 'tmp_seqs2mm2HBDxBase__unmapped.fa', + 'tmp_seqs' + ) + + # sequential mm mapping to Pseudo-HBDxBase + print('-------- map to Pseudo-HBDxBase --------') + + print('-------- mapping seqs (0 mm) --------') + map_seq_2_HBDxBase_pseudo( + 0, + 'tmp_seqs3mm2HBDxBase__unmapped.fa', + 'tmp_seqs' + ) + + print('-------- mapping seqs (1 mm) --------') + map_seq_2_HBDxBase_pseudo( + 1, + 'tmp_seqs0mm2HBDxBase_pseudo__unmapped.fa', + 'tmp_seqs' + ) + + print('-------- mapping seqs (2 mm) --------') + map_seq_2_HBDxBase_pseudo( + 2, + 'tmp_seqs1mm2HBDxBase_pseudo__unmapped.fa', + 'tmp_seqs' + ) + + print('-------- mapping seqs (3 mm) --------') + map_seq_2_HBDxBase_pseudo( + 3, + 'tmp_seqs2mm2HBDxBase_pseudo__unmapped.fa', + 'tmp_seqs' + ) + + + # concatenate files + print('-------- concatenate mapping files --------') + os.system("cat tmp_seqs0mm2HBDxBase.txt tmp_seqs1mm2HBDxBase.txt tmp_seqs2mm2HBDxBase.txt tmp_seqs3mm2HBDxBase.txt tmp_seqs0mm2HBDxBase_pseudo.txt tmp_seqs1mm2HBDxBase_pseudo.txt tmp_seqs2mm2HBDxBase_pseudo.txt tmp_seqs3mm2HBDxBase_pseudo.txt > seqsmapped2HBDxBase_combined.txt") + + print('\n') + + # mapping to adapters (allowing for 3 mms) + print('-------- map to adapters (3 mm) --------') + map_seq_2_adapters( + 'seqs.fa', + 'tmp_seqs' + ) + + # mapping to genome (more than 50 alignments are not reported) + print('-------- map to human genome --------') + + print('-------- mapping seqs (0 mm) --------') + map_seq_2_genome( + 'seqs.fa', + 'tmp_seqs' + ) + + + ## concatenate files + print('-------- concatenate mapping files --------') + os.system("cp tmp_seqs0mm2genome.txt seqsmapped2genome_combined.txt") + + print('\n') + + ## intersect genome mapping hits with TE.bed + print('-------- intersect genome mapping hits with TE.bed --------') + # convert to BED format + os.system("awk 'BEGIN {FS= \"\t\"; OFS=\"\t\"} {print $3, $4, $4+length($5)-1, $5, 111, $2}' seqsmapped2genome_combined.txt > tmp_seqsmapped2genome_combined.bed") + # intersect with TE.bed (force strandedness -> fetch only sRNA_sequence and TE_name -> aggregate TE annotation on sequences) + os.system(f"bedtools intersect -a tmp_seqsmapped2genome_combined.bed -b {TE_path} -wa -wb -s" + "| awk '{print $4,$10}' | awk '{a[$1]=a[$1]\";\"$2} END {for(i in a) print i\"\t\"substr(a[i],2)}' > tmp_seqsmapped2genome_intersect_TE.txt") + + # mapping to bacterial and viral genomes (more than 10 alignments are not reported) + print('-------- map to bacterial and viral genome --------') + + print('-------- mapping seqs (0 mm) --------') + map_seq_2_bacterial_viral( + 'seqs.fa', + 'tmp_seqs' + ) + + ## concatenate files + print('-------- concatenate mapping files --------') + os.system("cat tmp_seqs0mm2bacterial.txt tmp_seqs0mm2viral.txt > seqsmapped2bacterialviral_combined.txt") + + print('\n') + + + + diff 
--git a/kba_pipeline/src/precursor_bins.py b/kba_pipeline/src/precursor_bins.py new file mode 100644 index 0000000000000000000000000000000000000000..c9242316ef344892426a6bc98b2a3433c1fe5895 --- /dev/null +++ b/kba_pipeline/src/precursor_bins.py @@ -0,0 +1,127 @@ +#%% +import pandas as pd +from typing import List +from collections.abc import Callable + +def load_HBDxBase(): + version = '_v4' + HBDxBase_file = f'../../references/HBDxBase/HBDxBase_all{version}.csv' + HBDxBase_df = pd.read_csv(HBDxBase_file, index_col=0) + HBDxBase_df.loc[:,'precursor_bins'] = (HBDxBase_df.precursor_length/25).astype(int) + return HBDxBase_df + +def compute_dynamic_bin_size(precursor_len:int, name:str=None, min_bin_size:int=20, max_bin_size:int=30) -> List[int]: + ''' + This function splits precursor to bins of size max_bin_size + if the last bin is smaller than min_bin_size, it will split the precursor to bins of size max_bin_size-1 + This process will continue until the last bin is larger than min_bin_size. + if the min bin size is reached and still the last bin is smaller than min_bin_size, the last two bins will be merged. + so the maximimum bin size possible would be min_bin_size+(min_bin_size-1) = 39 + ''' + def split_precursor_to_bins(precursor_len,max_bin_size): + ''' + This function splits precursor to bins of size max_bin_size + ''' + precursor_bin_lens = [] + for i in range(0, precursor_len, max_bin_size): + if i+max_bin_size < precursor_len: + precursor_bin_lens.append(max_bin_size) + else: + precursor_bin_lens.append(precursor_len-i) + return precursor_bin_lens + + if precursor_len < min_bin_size: + return [precursor_len] + else: + precursor_bin_lens = split_precursor_to_bins(precursor_len,max_bin_size) + reduced_len = max_bin_size-1 + while precursor_bin_lens[-1] < min_bin_size: + precursor_bin_lens = split_precursor_to_bins(precursor_len,reduced_len) + reduced_len -= 1 + if reduced_len < min_bin_size: + #add last two bins together + precursor_bin_lens[-2] += precursor_bin_lens[-1] + precursor_bin_lens = precursor_bin_lens[:-1] + break + + return precursor_bin_lens + +def get_bin_no_from_pos(precursor_len:int,position:int,name:str=None,min_bin_size:int=20,max_bin_size:int=30) -> int: + ''' + This function returns the bin number of a position in a precursor + bins start from 1 + ''' + precursor_bin_lens = compute_dynamic_bin_size(precursor_len=precursor_len,name=name,min_bin_size=min_bin_size,max_bin_size=max_bin_size) + bin_no = 0 + for i,bin_len in enumerate(precursor_bin_lens): + if position < bin_len: + bin_no = i + break + else: + position -= bin_len + return bin_no+1 + +def get_bin_with_max_overlap(row) -> int: + ''' + This function returns the bin number of a fragment that overlaps the most with the fragment + ''' + precursor_len = row.precursor_length + start_frag_pos = row.ref_start + frag_len = row.seq_length + name = row.precursor_name_full + min_bin_size = 20 + max_bin_size = 30 + precursor_bin_lens = compute_dynamic_bin_size(precursor_len=precursor_len,name=name,min_bin_size=min_bin_size,max_bin_size=max_bin_size) + bin_no = 0 + for i,bin_len in enumerate(precursor_bin_lens): + if start_frag_pos < bin_len: + #get overlap with curr bin + overlap = min(bin_len-start_frag_pos,frag_len) + + if overlap > frag_len/2: + bin_no = i + else: + bin_no = i+1 + break + + else: + start_frag_pos -= bin_len + #get bin start and bin end + bin_start,bin_end = sum(precursor_bin_lens[:bin_no]),sum(precursor_bin_lens[:bin_no+1]) + row['bin_start'] = bin_start + row['bin_end'] = bin_end + 
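+    # Illustrative note (added example, not original logic): with
+    # precursor_bin_lens == [30, 30, 25] and bin_no == 1, the two sums above
+    # give bin_start == 30 and bin_end == 60, i.e. the coordinate range of the
+    # selected bin within the precursor.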
+    row['subclass_name'] = name + '_bin-' + str(bin_no+1)
+    row['precursor_bins'] = len(precursor_bin_lens)
+    row['subclass_name_bin_pos'] = name + '_binpos-' + str(bin_start) + ':' + str(bin_end)
+    return row
+
+def convert_bin_to_pos(precursor_len:int,bin_no:int,bin_function:Callable=compute_dynamic_bin_size,name:str=None,min_bin_size:int=20,max_bin_size:int=30):
+    '''
+    This function returns the start and end position of a bin
+    '''
+    precursor_bin_lens = bin_function(precursor_len=precursor_len,name=name,min_bin_size=min_bin_size,max_bin_size=max_bin_size)
+    start_pos = 0
+    end_pos = 0
+    for i,bin_len in enumerate(precursor_bin_lens):
+        if i+1 == bin_no:
+            end_pos = start_pos+bin_len
+            break
+        else:
+            start_pos += bin_len
+    return start_pos,end_pos
+
+#main
+if __name__ == '__main__':
+    #read HBDxBase
+    HBDxBase_df = load_HBDxBase()
+    min_bin_size = 20
+    max_bin_size = 30
+    #select precursors whose index includes 'rRNA' but not 'pseudo'
+    rRNA_df = HBDxBase_df[HBDxBase_df.index.str.contains('rRNA') * ~HBDxBase_df.index.str.contains('pseudo')]
+
+    #compute the bins of the first rRNA precursor and the bin containing position 1
+    bins = compute_dynamic_bin_size(len(rRNA_df.iloc[0].sequence),rRNA_df.iloc[0].name,min_bin_size,max_bin_size)
+    bin_no = get_bin_no_from_pos(len(rRNA_df.iloc[0].sequence),name=rRNA_df.iloc[0].name,position=1)
+    #get_bin_with_max_overlap expects a row-like object; build one for the first rRNA precursor
+    annotation_bin = get_bin_with_max_overlap(pd.Series({
+        'precursor_length': len(rRNA_df.iloc[0].sequence),
+        'ref_start': 1,
+        'seq_length': 50,
+        'precursor_name_full': rRNA_df.iloc[0].name}))
+
+# %%
diff --git a/kba_pipeline/src/utils.py b/kba_pipeline/src/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..74ba068ac8b687628f011663ac1dcdb0b9261803
--- /dev/null
+++ b/kba_pipeline/src/utils.py
@@ -0,0 +1,163 @@
+import pandas as pd
+import os
+import errno
+from pathlib import Path
+from Bio.SeqIO.FastaIO import SimpleFastaParser
+from datetime import datetime
+from getpass import getuser
+
+import logging
+from rich.logging import RichHandler
+from functools import wraps
+from time import perf_counter
+from typing import Callable
+
+default_path = '../outputs/'
+
+def humanize_time(time_in_seconds: float, /) -> str:
+    """Return a nicely human-readable string for a time given in seconds.
+
+    Parameters
+    ----------
+    time_in_seconds : float
+        Time in seconds (may be fractional).
+
+    Returns
+    -------
+    str
+        A description of the time in one of the forms:
+        - 300.1 ms
+        - 4.5 sec
+        - 5 min 43.1 sec
+    """
+    sgn = "" if time_in_seconds >= 0 else "- "
+    time_in_seconds = abs(time_in_seconds)
+    if time_in_seconds < 1:
+        return f"{sgn}{time_in_seconds*1e3:.1f} ms"
+    elif time_in_seconds < 60:
+        return f"{sgn}{time_in_seconds:.1f} sec"
+    else:
+        return f"{sgn}{int(time_in_seconds//60)} min {time_in_seconds%60:.1f} sec"
+
+
+class log_time:
+    """A decorator / context manager to log the time a certain function / code block took.
+
+    Usage either with:
+
+        @log_time(log)
+        def function_getting_logged_every_time(…):
+            …
+
+    producing:
+
+        function_getting_logged_every_time took 5 sec.
+
+    or:
+
+        with log_time(log, "Name of this codeblock"):
+            …
+
+    producing:
+
+        Name of this codeblock took 5 sec.
+    """
+
+    def __init__(self, logger: logging.Logger, name: str = None):
+        """
+        Parameters
+        ----------
+        logger : logging.Logger
+            The logger to use for logging the time, if None use print.
+ name : str, optional + The name in the message, when used as a decorator this defaults to the function name, by default None + """ + self.logger = logger + self.name = name + + def __call__(self, func: Callable): + if self.name is None: + self.name = func.__qualname__ + + @wraps(func) + def inner(*args, **kwds): + with self: + return func(*args, **kwds) + + return inner + + def __enter__(self): + self.start_time = perf_counter() + + def __exit__(self, *exc): + self.exit_time = perf_counter() + + time_delta = humanize_time(self.exit_time - self.start_time) + if self.logger is None: + print(f"{self.name} took {time_delta}.") + else: + self.logger.info(f"{self.name} took {time_delta}.") + + +def write_2_log(log_file): + # Setup logging + log_file_handler = logging.FileHandler(log_file) + log_file_handler.setLevel(logging.INFO) + log_file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) + log_rich_handler = RichHandler() + log_rich_handler.setLevel(logging.INFO) #cli_args.log_level + log_rich_handler.setFormatter(logging.Formatter("%(message)s")) + logging.basicConfig(level=logging.INFO, datefmt="[%X]", handlers=[log_file_handler, log_rich_handler]) + + +def fasta2df(path): + with open(path) as fasta_file: + identifiers = [] + seqs = [] + for header, sequence in SimpleFastaParser(fasta_file): + identifiers.append(header) + seqs.append(sequence) + + fasta_df = pd.DataFrame(seqs, identifiers, columns=['sequence']) + fasta_df['sequence'] = fasta_df.sequence.apply(lambda x: x.replace('U','T')) + return fasta_df + + + +def fasta2df_subheader(path, id_pos): + with open(path) as fasta_file: + identifiers = [] + seqs = [] + for header, sequence in SimpleFastaParser(fasta_file): + identifiers.append(header.split(None)[id_pos]) + seqs.append(sequence) + + fasta_df = pd.DataFrame(seqs, identifiers, columns=['sequence']) + fasta_df['sequence'] = fasta_df.sequence.apply(lambda x: x.replace('U','T')) + return fasta_df + + + +def build_bowtie_index(bowtie_index_file): + #index_example = Path(bowtie_index_file + '.1.ebwt') + #if not index_example.is_file(): + print('-------- index is build --------') + os.system(f"bowtie-build {bowtie_index_file + '.fa'} {bowtie_index_file}") + #else: print('-------- previously built index is used --------') + + + +def make_output_dir(fasta_file): + output_dir = default_path + datetime.now().strftime('%Y-%m-%d') + ('__') + fasta_file.replace('.fasta', '').replace('.fa', '') + '/' + try: + os.makedirs(output_dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise # This was not a "directory exist" error.. 
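+    # Illustrative example (hypothetical date and input, added comment): for
+    # fasta_file='seqs.fa' run on 2024-01-01 this returns
+    # '../outputs/2024-01-01__seqs/'.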
+ return output_dir + + +def reverse_complement(seq): + complement = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A'} + return ''.join([complement[base] for base in seq[::-1]]) + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0b827747db832ee63e70550057b767c942362c42 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,18 @@ +anndata==0.8.0 +dill==0.3.6 +hydra-core==1.3.0 +imbalanced-learn==0.9.1 +matplotlib==3.5.3 +numpy==1.22.3 +omegaconf==2.2.2 +pandas==1.5.2 +plotly==5.10.0 +PyYAML==6.0 +rich==12.6.0 +viennarna==2.5.0a5 +scanpy==1.9.1 +scikit_learn==1.2.0 +skorch==0.12.1 +torch==1.10.1 +tensorboard==2.11.2 +Levenshtein==0.21.0 \ No newline at end of file diff --git a/scripts/test_inference_api.py b/scripts/test_inference_api.py new file mode 100644 index 0000000000000000000000000000000000000000..409287de67256a345b77f268363e59d32a54b3a2 --- /dev/null +++ b/scripts/test_inference_api.py @@ -0,0 +1,29 @@ +from transforna import predict_transforna, predict_transforna_all_models + +seqs = [ +'AACGAAGCTCGACTTTTAAGG', +'GTCCACCCCAAAGCGTAGG'] + +path_to_models = '/path/to/tcga/models/' +sc_preds_id_df = predict_transforna_all_models(seqs,path_to_models = path_to_models) #/models/tcga/ +#%% +#get sc predictions for models trained on id (in distribution) +sc_preds_id_df = predict_transforna(seqs, model="seq",trained_on='id',path_to_models = path_to_models) +#get sc predictions for models trained on full (all sub classes) +sc_preds_df = predict_transforna(seqs, model="seq",path_to_models = path_to_models) +#predict using models trained on major class +mc_preds_df = predict_transforna(seqs, model="seq",mc_or_sc='major_class',path_to_models = path_to_models) +#get logits +logits_df = predict_transforna(seqs, model="seq",logits_flag=True,path_to_models = path_to_models) +#get embedds +embedd_df = predict_transforna(seqs, model="seq",embedds_flag=True,path_to_models = path_to_models) +#get the top 4 similar sequences +sim_df = predict_transforna(seqs, model="seq",similarity_flag=True,n_sim=4,path_to_models = path_to_models) +#get umaps +umaps_df = predict_transforna(seqs, model="seq",umaps_flag=True,path_to_models = path_to_models) + + +all_preds_df = predict_transforna_all_models(seqs,path_to_models=path_to_models) +all_preds_df + +# %% diff --git a/scripts/train.sh b/scripts/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..417f4c11bbfd7b895394d612b398c5d0cd08185a --- /dev/null +++ b/scripts/train.sh @@ -0,0 +1,80 @@ +#data_time for hydra output folder +get_data_time(){ + date=$(ls outputs/ | head -n 1) + time=$(ls outputs/*/ | head -n 1) + date=$date + time=$time +} + +train_model(){ + python -m transforna --config-dir="/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/conf"\ + model_name=$1 trained_on=$2 num_replicates=$4 + + get_data_time + #rename the folder to model_name + mv outputs/$date/$time outputs/$date/$3 + ls outputs/$date/ + rm -rf models/tcga/TransfoRNA_${2^^}/$5/$3 + + + + mv -f outputs/$date/$3 models/tcga/TransfoRNA_${2^^}/$5/ + rm -rf outputs/ + +} +#activate transforna environment +eval "$(conda shell.bash hook)" +conda activate transforna + +#create the models folder if it does not exist +if [[ ! -d "models/tcga/TransfoRNA_ID/major_class" ]]; then + mkdir -p models/tcga/TransfoRNA_ID/major_class +fi +if [[ ! -d "models/tcga/TransfoRNA_FULL/sub_class" ]]; then + mkdir -p models/tcga/TransfoRNA_FULL/sub_class +fi +if [[ ! 
-d "models/tcga/TransfoRNA_ID/sub_class" ]]; then + mkdir -p models/tcga/TransfoRNA_ID/sub_class +fi +if [[ ! -d "models/tcga/TransfoRNA_FULL/major_class" ]]; then + mkdir -p models/tcga/TransfoRNA_FULL/major_class +fi +#remove the outputs folder +rm -rf outputs + + +#define models +models=("seq" "seq-seq" "seq-rev" "seq-struct" "baseline") +models_capitalized=("Seq" "Seq-Seq" "Seq-Rev" "Seq-Struct" "Baseline") + + +num_replicates=5 + + +############train major_class_hico + +##replace clf_target:str = 'sub_class_hico' to clf_target:str = 'major_class_hico' in ../conf/train_model_configs/tcga.py +sed -i "s/clf_target:str = 'sub_class_hico'/clf_target:str = 'major_class_hico'/g" conf/train_model_configs/tcga.py +#print the file content +cat conf/train_model_configs/tcga.py +#loop and train +for i in ${!models[@]}; do + echo "Training model ${models_capitalized[$i]} for id on major_class" + train_model ${models[$i]} id ${models_capitalized[$i]} $num_replicates "major_class" + echo "Training model ${models[$i]} for full on major_class" + train_model ${models[$i]} full ${models_capitalized[$i]} 1 "major_class" +done + + +############train sub_class_hico + +#replace clf_target:str = 'major_class_hico' to clf_target:str = 'sub_class_hico' in ../conf/train_model_configs/tcga.py +sed -i "s/clf_target:str = 'major_class_hico'/clf_target:str = 'sub_class_hico'/g" conf/train_model_configs/tcga.py + +for i in ${!models[@]}; do + echo "Training model ${models_capitalized[$i]} for id on sub_class" + train_model ${models[$i]} id ${models_capitalized[$i]} $num_replicates "sub_class" + echo "Training model ${models[$i]} for full on sub_class" + train_model ${models[$i]} full ${models_capitalized[$i]} 1 "sub_class" +done + diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce0ab302411aa2c397c1b90236e421c702c1af7 --- /dev/null +++ b/setup.py @@ -0,0 +1,40 @@ +from setuptools import find_packages, setup + +setup( + name='TransfoRNA', + version='0.0.1', + description='TransfoRNA: Navigating the Uncertainties of Small RNA Annotation with an Adaptive Machine Learning Strategy', + url='https://github.com/gitHBDX/TransfoRNA', + author='YasserTaha,JuliaJehn', + author_email='ytaha@hb-dx.com,jjehn@hb-dx.com,tsikosek@hb-dx.com', + classifiers=[ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Biological Researchers', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3.9', + ], + packages=find_packages(include=['transforna', 'transforna.*']), + install_requires=[ + "anndata==0.8.0", + "dill==0.3.6", + "hydra-core==1.3.0", + "imbalanced-learn==0.9.1", + "matplotlib==3.5.3", + "numpy==1.22.3", + "omegaconf==2.2.2", + "pandas==1.5.2", + "plotly==5.10.0", + "PyYAML==6.0", + "rich==12.6.0", + "viennarna>=2.5.0a5", + "scanpy==1.9.1", + "scikit_learn==1.2.0", + "skorch==0.12.1", + "torch==1.10.1", + "tensorboard==2.16.2", + "Levenshtein==0.21.0" + ], + python_requires='>=3.9', + #move yaml files to package + package_data={'': ['*.yaml']}, +) diff --git a/transforna/__init__.py b/transforna/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c10d47570d154183c8f885d991bc5330570a4641 --- /dev/null +++ b/transforna/__init__.py @@ -0,0 +1,7 @@ +from .src.callbacks import * +from .src.inference import * +from .src.model import * +from .src.novelty_prediction import * +from .src.processing import * +from .src.train import * +from .src.utils import * diff --git a/transforna/__main__.py b/transforna/__main__.py 
new file mode 100644 index 0000000000000000000000000000000000000000..aaad317cd1165a49a2ffd8635e3ca8eef678475a --- /dev/null +++ b/transforna/__main__.py @@ -0,0 +1,54 @@ +import logging +import os +import sys +import warnings + +import hydra +from hydra.core.hydra_config import HydraConfig +from hydra.utils import instantiate +from omegaconf import DictConfig + +from transforna import compute_cv, infer_benchmark, infer_tcga, train + +warnings.filterwarnings("ignore") + + +logger = logging.getLogger(__name__) + +def add_config_to_sys_path(): + cfg = HydraConfig.get() + config_path = [path["path"] for path in cfg.runtime.config_sources if path["schema"] == "file"][0] + sys.path.append(config_path) + +#transforna could called from anywhere: +#python -m transforna --config-dir = /path/to/configs +@hydra.main(config_path='../conf', config_name="main_config") +def main(cfg: DictConfig) -> None: + add_config_to_sys_path() + #get path of hydra outputs folder + output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir + + path = os.getcwd() + #init train and model config + cfg['train_config'] = instantiate(cfg['train_config']).__dict__ + cfg['model_config'] = instantiate(cfg['model_config']).__dict__ + + #update model config with the name of the model + cfg['model_config']["model_input"] = cfg["model_name"] + + #inference or train + if cfg["inference"]: + logger.info(f"Started inference on {cfg['task']}") + if cfg['task'] == 'tcga': + return infer_tcga(cfg,path=path) + else: + return infer_benchmark(cfg,path=path) + else: + if cfg["cross_val"]: + compute_cv(cfg,path,output_dir=output_dir) + + else: + train(cfg,path=path,output_dir=output_dir) + +if __name__ == "__main__": + main() diff --git a/transforna/bin/figure_scripts/figure_4_table_3.py b/transforna/bin/figure_scripts/figure_4_table_3.py new file mode 100644 index 0000000000000000000000000000000000000000..d023d2a2676b0fedaa65c9beb82c9aca366c153b --- /dev/null +++ b/transforna/bin/figure_scripts/figure_4_table_3.py @@ -0,0 +1,173 @@ + + +#%% +#read all files ending with dist_df in bin/lc_files/ +import pandas as pd +import glob +from plotly import graph_objects as go +from transforna import load,predict_transforna +all_df = pd.DataFrame() +files = glob.glob('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_files/*lev_dist_df.csv') +for file in files: + df = pd.read_csv(file) + all_df = pd.concat([all_df,df]) +all_df = all_df.drop(columns=['Unnamed: 0']) +all_df.loc[all_df.split.isnull(),'split'] = 'NA' + +#%% +#draw a box plot for the Ensemble model for each of the splits using seaborn +ensemble_df = all_df[all_df.Model == 'Ensemble'].reset_index(drop=True) +#remove other_affixes +ensemble_df = ensemble_df[ensemble_df.split != 'other_affixes'].reset_index(drop=True) +#rename +ensemble_df['split'] = ensemble_df['split'].replace({'5\'A-affixes':'Putative 5’-adapter prefixes','Fused':'Recombined'}) +ensemble_df['split'] = ensemble_df['split'].replace({'Relaxed-miRNA':'Isomirs'}) +#%% +#plot the boxplot using seaborn +import seaborn as sns +import matplotlib.pyplot as plt +sns.set_theme(style="whitegrid") +sns.set(rc={'figure.figsize':(15,10)}) +sns.set(font_scale=1.5) +order = ['LC-familiar','LC-novel','Random','Putative 5’-adapter prefixes','Recombined','NA','LOCO','Isomirs'] +ax = sns.boxplot(x="split", y="NLD", data=ensemble_df, palette="Set3",order=order,showfliers = True) + +#add Novelty Threshold line +ax.axhline(y=ensemble_df['Novelty Threshold'].mean(), color='g', linestyle='--',xmin=0,xmax=1) 
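+#Descriptive note (added): NLD is the normalized Levenshtein distance plotted on
+#the y-axis; sequences whose NLD exceeds the value in the 'Novelty Threshold'
+#column are counted as novel further down in this script.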
+#annotate mean of Novelty Threshold +ax.annotate('NLD threshold', xy=(1.5, ensemble_df['Novelty Threshold'].mean()), xytext=(1.5, ensemble_df['Novelty Threshold'].mean()-0.07), arrowprops=dict(facecolor='black', shrink=0.05)) +#rename +ax.set_xticklabels(['LC-Familiar','LC-Novel','Random','5’-adapter artefacts','Recombined','NA','LOCO','IsomiRs']) +#add title +ax.set_facecolor('None') +plt.title('Levenshtein Distance Distribution per Split on LC') +ax.set(xlabel='Split', ylabel='Normalized Levenshtein Distance') +#save legend +plt.legend(loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.,facecolor=None,framealpha=0.0) +plt.savefig('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_figures/lev_dist_no_out_boxplot.svg',dpi=400) +#tilt x axis labels +plt.xticks(rotation=-22.5) +#save svg +plt.savefig('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_figures/lev_dist_seaboarn_boxplot.svg',dpi=1000) +##save png +plt.savefig('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_figures/lev_dist_seaboarn_boxplot.png',dpi=1000) +#%% +bars = [r for r in ax.get_children()] +colors = [] +for c in bars[:-1]: + try: colors.append(c.get_facecolor()) + except: pass +isomir_color = colors[len(order)-1] +isomir_color = [255*x for x in isomir_color] +#covert to rgb('r','g','b','a') +isomir_color = 'rgb(%s,%s,%s,%s)'%(isomir_color[0],isomir_color[1],isomir_color[2],isomir_color[3]) + +#%% +relaxed_mirna_df = all_df[all_df.split == 'Relaxed-miRNA'] +models = relaxed_mirna_df.Model.unique() +percentage_dict = {} +for model in models: + model_df = relaxed_mirna_df[relaxed_mirna_df.Model == model] + #compute the % of sequences with NLD < Novelty Threshold for each model + percentage_dict[model] = len(model_df[model_df['NLD'] > model_df['Novelty Threshold']])/len(model_df) + percentage_dict[model]*=100 + +fig = go.Figure() +for model in ['Baseline','Seq','Seq-Seq','Seq-Struct','Seq-Rev','Ensemble']: + fig.add_trace(go.Bar(x=[model],y=[percentage_dict[model]],name=model,marker_color=isomir_color)) + #add percentage on top of each bar + fig.add_annotation(x=model,y=percentage_dict[model]+2,text='%s%%'%(round(percentage_dict[model],2)),showarrow=False) + #increase size of annotation + fig.update_annotations(dict(font_size=13)) +#add title in the center +fig.update_layout(title='Percentage of Isomirs considered novel per model') +fig.update_layout(xaxis_tickangle=+22.5) +fig.update_layout(showlegend=False) +#make transparent background +fig.update_layout(plot_bgcolor='rgba(0,0,0,0)') +#y axis label +fig.update_yaxes(title_text='Percentage of Novel Sequences') +#save svg +fig.show() +#save svg +#fig.write_image('relaxed_mirna_novel_perc_lc_barplot.svg') +#%% +#here we explore the false familiar of the ood lc set +ood_df = pd.read_csv('/nfs/home/yat_ldap/VS_Projects/TransfoRNA/bin/lc_files/LC-novel_lev_dist_df.csv') +mapping_dict_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v02/subclass_to_annotation.json' +mapping_dict = load(mapping_dict_path) + +LC_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v02/LC__ngs__DI_HB_GEL-23.1.2.h5ad' +ad = load(LC_path) +#%% +model = 'Ensemble' +ood_seqs = ood_df[(ood_df.Model == model).values * (ood_df['Is Familiar?'] == True).values].Sequence.tolist() +ood_predicted_labels = ood_df[(ood_df.Model == model).values * (ood_df['Is Familiar?'] == 
True).values].Labels.tolist() +ood_actual_labels = ad.var.loc[ood_seqs]['subclass_name'].values.tolist() +from transforna import correct_labels +ood_predicted_labels = correct_labels(ood_predicted_labels,ood_actual_labels,mapping_dict) + +#get indices where ood_predicted_labels == ood_actual_labels +correct_indices = [i for i, x in enumerate(ood_predicted_labels) if x != ood_actual_labels[i]] +#remove the indices from ood_seqs, ood_predicted_labels, ood_actual_labels +ood_seqs = [ood_seqs[i] for i in correct_indices] +ood_predicted_labels = [ood_predicted_labels[i] for i in correct_indices] +ood_actual_labels = [ood_actual_labels[i] for i in correct_indices] +#get the major class of the actual labels +ood_actual_major_class = [mapping_dict[label] if label in mapping_dict else 'None' for label in ood_actual_labels] +ood_predicted_major_class = [mapping_dict[label] if label in mapping_dict else 'None' for label in ood_predicted_labels ] +#get frequencies of each major class +from collections import Counter +ood_actual_major_class_freq = Counter(ood_actual_major_class) +ood_predicted_major_class_freq = Counter(ood_predicted_major_class) + + + +# %% +import plotly.express as px +major_classes = list(ood_actual_major_class_freq.keys()) + +ood_seqs_len = [len(seq) for seq in ood_seqs] +ood_seqs_len_freq = Counter(ood_seqs_len) +fig = px.bar(x=list(ood_seqs_len_freq.keys()),y=list(ood_seqs_len_freq.values())) +fig.show() + +#%% +import plotly.graph_objects as go +fig = go.Figure() +for major_class in major_classes: + len_dist = [len(ood_seqs[i]) for i, x in enumerate(ood_actual_major_class) if x == major_class] + len_dist_freq = Counter(len_dist) + fig.add_trace(go.Bar(x=list(len_dist_freq.keys()),y=list(len_dist_freq.values()),name=major_class)) +#stack +fig.update_layout(barmode='stack') +#make transparent background +fig.update_layout(plot_bgcolor='rgba(0,0,0,0)') +#set y axis label to Count and x axis label to Length +fig.update_layout(yaxis_title='Count',xaxis_title='Length') +#set title +fig.update_layout(title_text="Length distribution of false familiar sequences per major class") +#save as svg +fig.write_image('false_familiar_length_distribution_per_major_class_stacked.svg') +fig.show() + +# %% +#for each model, for each split, print Is Familiar? == True and print the number of sequences +for model in all_df.Model.unique(): + print('\n\n') + model_df = all_df[all_df.Model == model] + num_hicos = 0 + total_samples = 0 + for split in ['LC-familiar','LC-novel','LOCO','NA','Relaxed-miRNA']: + + split_df = model_df[model_df.split == split] + #print('Model: %s, Split: %s, Familiar: %s, Number of Sequences: %s'%(model,split,len(split_df[split_df['Is Familiar?'] == True]),len(split_df))) + #print model, split % + print('%s %s %s'%(model,split,len(split_df[split_df['Is Familiar?'] == True])/len(split_df)*100)) + if split != 'LC-novel': + num_hicos+=len(split_df[split_df['Is Familiar?'] == True]) + total_samples+=len(split_df) + #print % of hicos + print('%s %s %s'%(model,'HICO',num_hicos/total_samples*100)) + print(total_samples) +# %% diff --git a/transforna/bin/figure_scripts/figure_5_S10_table_4.py b/transforna/bin/figure_scripts/figure_5_S10_table_4.py new file mode 100644 index 0000000000000000000000000000000000000000..56cf9207efb53864f7ffbacd66b0397e8939e24a --- /dev/null +++ b/transforna/bin/figure_scripts/figure_5_S10_table_4.py @@ -0,0 +1,466 @@ +#in this file, the progression of the number of hicos per major class is computed per model +#this is done before ID, after FULL. 
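+#Terminology as used below (descriptive note, added): 'HICO' counts sequences
+#with hico == True in the TCGA dataframe, 'LOCO' counts sequences that carry a
+#subclass_name annotation but are not hico, and 'NA' counts sequences with
+#subclass_name == 'no_annotation'. 'ID' and 'FULL' refer to models trained on
+#in-distribution data vs. on all sub-classes (see scripts/test_inference_api.py).
+#Minimal sketch of the fractions computed step by step below, assuming a
+#dataframe with 'hico' and 'subclass_name' columns (hypothetical helper, not
+#called anywhere in this script):
+def annotation_fractions(df):
+    n = df.shape[0]
+    hico = df['hico'].sum() / n
+    loco = ((df.subclass_name != 'no_annotation').sum() - df['hico'].sum()) / n
+    na = (df.subclass_name == 'no_annotation').sum() / n
+    return {'HICO': hico, 'LOCO': loco, 'NA': na}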
+#%% +from transforna import load +from transforna import predict_transforna,predict_transforna_all_models +import pandas as pd +import plotly.graph_objects as go +import numpy as np + +def compute_overlap_models_ensemble(full_df:pd.DataFrame,mapping_dict:dict): + full_copy_df = full_df.copy() + full_copy_df['MC_Labels'] = full_copy_df['Net-Label'].map(mapping_dict) + #filter is familiar + full_copy_df = full_copy_df[full_copy_df['Is Familiar?']].set_index('Sequence') + #count the predicted miRNAs per each Model + full_copy_df.groupby('Model').MC_Labels.value_counts() + + #for eaach of the models and for each of the mc classes, get the overlap between the models predicting a certain mc and the ensemble having the same prediction + models = ['Baseline','Seq','Seq-Seq','Seq-Struct','Seq-Rev'] + mcs = full_copy_df.MC_Labels.value_counts().index.tolist() + mc_stats = {} + novel_resid = {} + mcs_predicted_by_only_one_model = {} + #add all mcs as keys to mc_stats and add all models as keys in every mc + for mc in mcs: + mc_stats[mc] = {} + novel_resid[mc] = {} + mcs_predicted_by_only_one_model[mc] = {} + for model in models: + mc_stats[mc][model] = 0 + novel_resid[mc][model] = 0 + mcs_predicted_by_only_one_model[mc][model] = 0 + + for mc in mcs: + ensemble_xrna = full_copy_df[full_copy_df.Model == 'Ensemble'].iloc[full_copy_df[full_copy_df.Model == 'Ensemble'].MC_Labels.str.contains(mc).values].index.tolist() + for model in models: + model_xrna = full_copy_df[full_copy_df.Model == model].iloc[full_copy_df[full_copy_df.Model == model].MC_Labels.str.contains(mc).values].index.tolist() + common_xrna = set(ensemble_xrna).intersection(set(model_xrna)) + try: + mc_stats[mc][model] = len(common_xrna)/len(ensemble_xrna) + except ZeroDivisionError: + mc_stats[mc][model] = 0 + #check how many sequences exist in ensemble but not in model + try: + novel_resid[mc][model] = len(set(ensemble_xrna).difference(set(model_xrna)))/len(ensemble_xrna) + except ZeroDivisionError: + novel_resid[mc][model] = 0 + #check how many sequences exist in model and in ensemble but not in other models + other_models_xrna = [] + for other_model in models: + if other_model != model: + other_models_xrna.extend(full_copy_df[full_copy_df.Model == other_model].iloc[full_copy_df[full_copy_df.Model == other_model].MC_Labels.str.contains(mc).values].index.tolist()) + #check how many of model_xrna are not in other_models_xrna and are in ensemble_xrna + try: + mcs_predicted_by_only_one_model[mc][model] = len(set(model_xrna).difference(set(other_models_xrna)).intersection(set(ensemble_xrna)))/len(ensemble_xrna) + except ZeroDivisionError: + mcs_predicted_by_only_one_model[mc][model] = 0 + + return models,mc_stats,novel_resid,mcs_predicted_by_only_one_model + + +def plot_bar_overlap_models_ensemble(models,mc_stats,novel_resid,mcs_predicted_by_only_one_model): + #plot the result as bar plot per mc + import plotly.graph_objects as go + import numpy as np + import plotly.express as px + #square plot with mc classes on the x axis and the number of hicos on the y axis before ID, after ID, after FULL + #add cascaded bar plot for novel resid. 
one per mc per model + positions = np.arange(len(models)) + fig = go.Figure() + for model in models: + fig.add_trace(go.Bar( + x=list(mc_stats.keys()), + y=[mc_stats[mc][model] for mc in mc_stats.keys()], + name=model, + marker_color=px.colors.qualitative.Plotly[models.index(model)] + )) + + fig.add_trace(go.Bar( + x=list(mc_stats.keys()), + y=[mcs_predicted_by_only_one_model[mc][model] for mc in mc_stats.keys()], + #base = [mc_stats[mc][model] for mc in mc_stats.keys()], + name = 'novel', + marker_color='lightgrey' + )) + fig.update_layout(title='Overlap between Ensemble and other models per MC class') + + return fig + +def plot_heatmap_overlap_models_ensemble(models,mc_stats,novel_resid,mcs_predicted_by_only_one_model,what_to_plot='overlap'): + ''' + This function computes a heatmap of the overlap between the ensemble and the other models per mc class + input: + models: list of models + mc_stats: dictionary with mc classes as keys and models as keys of the inner dictionary. values represent overlap between each model and the ensemble + novel_resid: dictionary with mc classes as keys and models as keys of the inner dictionary. values represent the % of sequences that are predicted by the ensemble as familiar but with specific model as novel + mcs_predicted_by_only_one_model: dictionary with mc classes as keys and models as keys of the inner dictionary. values represent the % of sequences that are predicted as familiar by only one model + what_to_plot: string. 'overlap' for overlap between ensemble and other models, 'novel' for novel resid, 'only_one_model' for mcs predicted as novel by only one model + + ''' + + if what_to_plot == 'overlap': + plot_dict = mc_stats + elif what_to_plot == 'novel': + plot_dict = novel_resid + elif what_to_plot == 'only_one_model': + plot_dict = mcs_predicted_by_only_one_model + + import plotly.figure_factory as ff + fig = ff.create_annotated_heatmap( + z=[[plot_dict[mc][model] for mc in plot_dict.keys()] for model in models], + x=list(plot_dict.keys()), + y=models, + annotation_text=[[str(round(plot_dict[mc][model],2)) for mc in plot_dict.keys()] for model in models], + font_colors=['black'], + colorscale='Blues' + ) + #set x axis order + order_x_axis = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','miRNA','lncRNA','piRNA','YRNA','vtRNA'] + fig.update_xaxes(type='category',categoryorder='array',categoryarray=order_x_axis) + + + fig.update_xaxes(side='bottom') + if what_to_plot == 'overlap': + fig.update_layout(title='Overlap between Ensemble and other models per MC class') + elif what_to_plot == 'novel': + fig.update_layout(title='Novel resid between Ensemble and other models per MC class') + elif what_to_plot == 'only_one_model': + fig.update_layout(title='MCs predicted by only one model') + return fig +#%% +#read TCGA +dataset_path_train: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv' +models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/' +tcga_df = load(dataset_path_train) +tcga_df.set_index('sequence',inplace=True) +loco_hico_na_stats_before = {} +loco_hico_na_stats_before['HICO'] = sum(tcga_df['hico'])/tcga_df.shape[0] +before_hico_seqs = tcga_df['subclass_name'][tcga_df['hico'] == True].index.values +loco_hico_na_stats_before['LOCO'] = (sum(tcga_df.subclass_name != 'no_annotation') - sum(tcga_df['hico']))/tcga_df.shape[0] +before_loco_seqs = tcga_df[tcga_df.hico!=True][tcga_df.subclass_name != 'no_annotation'].index.values +loco_hico_na_stats_before['NA'] = 
sum(tcga_df.subclass_name == 'no_annotation')/tcga_df.shape[0] +before_na_seqs = tcga_df[tcga_df.subclass_name == 'no_annotation'].index.values +#load mapping dict +mapping_dict_path: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json' +mapping_dict = load(mapping_dict_path) +hico_seqs = tcga_df['subclass_name'][tcga_df['hico'] == True].index.values +hicos_mc_before_id_stats = tcga_df.loc[hico_seqs].subclass_name.map(mapping_dict).value_counts() +#remove mcs with ; in them +#hicos_mc_before_id_stats = hicos_mc_before_id_stats[~hicos_mc_before_id_stats.index.str.contains(';')] +seqs_non_hico_id = tcga_df['subclass_name'][tcga_df['hico'] == False].index.values +id_df = predict_transforna(sequences=seqs_non_hico_id,model='Seq-Rev',trained_on='id',path_to_models=models_path) +id_df = id_df[id_df['Is Familiar?']].set_index('Sequence') +#print the percentage of sequences with no_annotation and with +print('Percentage of sequences with no annotation: %s'%(id_df[id_df['Net-Label'] == 'no_annotation'].shape[0]/id_df.shape[0])) +print('Percentage of sequences with annotation: %s'%(id_df[id_df['Net-Label'] != 'no_annotation'].shape[0]/id_df.shape[0])) + +#%% +hicos_mc_after_id_stats = id_df['Net-Label'].map(mapping_dict).value_counts() +#remove mcs with ; in them +#hicos_mc_after_id_stats = hicos_mc_after_id_stats[~hicos_mc_after_id_stats.index.str.contains(';')] +#add missing major classes with zeros +for mc in hicos_mc_before_id_stats.index: + if mc not in hicos_mc_after_id_stats.index: + hicos_mc_after_id_stats[mc] = 0 +hicos_mc_after_id_stats = hicos_mc_after_id_stats+hicos_mc_before_id_stats + +#%% +seqs_non_hico_full = list(set(seqs_non_hico_id).difference(set(id_df.index.values))) +full_df = predict_transforna_all_models(sequences=seqs_non_hico_full,trained_on='full',path_to_models=models_path) +#UNCOMMENT TO COMPUTE BEFORE AND AFTER PER MC: table_4 +#ensemble_df = full_df[full_df['Model']=='Ensemble'] +#ensemble_df['Major Class'] = ensemble_df['Net-Label'].map(mapping_dict) +#new_hico_mcs= ensemble_df['Major Class'].value_counts() +#ann_hico_mcs = tcga_df[tcga_df['hico'] == True]['small_RNA_class_annotation'].value_counts() + +#%%% +inspect_model = True +if inspect_model: + #from transforna import compute_overlap_models_ensemble,plot_heatmap_overlap_models_ensemble + models, mc_stats, novel_resid, mcs_predicted_by_only_one_model = compute_overlap_models_ensemble(full_df,mapping_dict) + fig = plot_heatmap_overlap_models_ensemble(models,mc_stats,novel_resid,mcs_predicted_by_only_one_model,what_to_plot='overlap') + fig.show() + +#%% +df = full_df[full_df.Model == 'Ensemble'] +df = df[df['Is Familiar?']].set_index('Sequence') +print('Percentage of sequences with no annotation: %s'%(df[df['Is Familiar?'] == False].shape[0]/df.shape[0])) +print('Percentage of sequences with annotation: %s'%(df[df['Is Familiar?'] == True].shape[0]/df.shape[0])) +hicos_mc_after_full_stats = df['Net-Label'].map(mapping_dict).value_counts() +#remove mcs with ; in them +#hicos_mc_after_full_stats = hicos_mc_after_full_stats[~hicos_mc_after_full_stats.index.str.contains(';')] +#add missing major classes with zeros +for mc in hicos_mc_after_id_stats.index: + if mc not in hicos_mc_after_full_stats.index: + hicos_mc_after_full_stats[mc] = 0 +hicos_mc_after_full_stats = hicos_mc_after_full_stats + hicos_mc_after_id_stats + +# %% +#reorder the index of the series +hicos_mc_before_id_stats = hicos_mc_before_id_stats.reindex(hicos_mc_after_full_stats.index) +hicos_mc_after_id_stats = 
hicos_mc_after_id_stats.reindex(hicos_mc_after_full_stats.index) +#plot the progression of the number of hicos per major class, before ID, after ID, after FULL as a bar plot +#%% +#%% +training_mcs = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','YRNA','lncRNA'] +hicos_mc_before_id_stats_train = hicos_mc_before_id_stats[training_mcs] +hicos_mc_after_id_stats_train = hicos_mc_after_id_stats[training_mcs] +hicos_mc_after_full_stats_train = hicos_mc_after_full_stats[training_mcs] +#plot the progression of the number of hicos per major class, before ID, after ID, after FULL as a bar plot +import plotly.graph_objects as go +import numpy as np +import plotly.io as pio +import plotly.express as px + +#make a square plot with mc classes on the x axis and the number of hicos on the y axis before ID, after ID, after FULL +fig = go.Figure() +fig.add_trace(go.Bar( + x=hicos_mc_before_id_stats_train.index, + y=hicos_mc_before_id_stats_train.values, + name='Before ID', + marker_color='rgb(31, 119, 180)', + opacity = 0.5 +)) +fig.add_trace(go.Bar( + x=hicos_mc_after_id_stats_train.index, + y=hicos_mc_after_id_stats_train.values, + name='After ID', + marker_color='rgb(31, 119, 180)', + opacity=0.75 +)) +fig.add_trace(go.Bar( + x=hicos_mc_after_full_stats_train.index, + y=hicos_mc_after_full_stats_train.values, + name='After FULL', + marker_color='rgb(31, 119, 180)', + opacity=1 +)) +#make log scale +fig.update_layout( + title='Progression of the Number of HICOs per Major Class', + xaxis_tickfont_size=14, + yaxis=dict( + title='Number of HICOs', + titlefont_size=16, + tickfont_size=14, + ), + xaxis=dict( + title='Major Class', + titlefont_size=16, + tickfont_size=14, + ), + legend=dict( + x=0.8, + y=1.0, + bgcolor='rgba(255, 255, 255, 0)', + bordercolor='rgba(255, 255, 255, 0)' + ), + barmode='group', + bargap=0.15, + bargroupgap=0.1 +) +#make transparent background +fig.update_layout(plot_bgcolor='rgba(0,0,0,0)') +#log scalew +fig.update_yaxes(type="log") + +fig.update_layout(legend=dict( + yanchor="top", + y=0.99, + xanchor="left", + x=0.01 +)) +#tilt the x axis labels +fig.update_layout(xaxis_tickangle=22.5) +#set the range of the y axis +fig.update_yaxes(range=[0, 4.5]) +fig.update_layout(legend=dict( + orientation="h", + yanchor="bottom", + y=1.02, + xanchor="right", + x=1 +)) +fig.write_image("progression_hicos_per_mc_train.svg") +fig.show() +#%% +eval_mcs = ['miRNA','miscRNA','piRNA','vtRNA'] +hicos_mc_before_id_stats_eval = hicos_mc_before_id_stats[eval_mcs] +hicos_mc_after_full_stats_eval = hicos_mc_after_full_stats[eval_mcs] + +hicos_mc_after_full_stats_eval.index = hicos_mc_after_full_stats_eval.index + '*' +hicos_mc_before_id_stats_eval.index = hicos_mc_before_id_stats_eval.index + '*' +#%% +#plot the progression of the number of hicos per major class, before ID, after ID, after FULL as a bar plot +import plotly.graph_objects as go +import numpy as np +import plotly.io as pio +import plotly.express as px + +fig2 = go.Figure() +fig2.add_trace(go.Bar( + x=hicos_mc_before_id_stats_eval.index, + y=hicos_mc_before_id_stats_eval.values, + name='Before ID', + marker_color='rgb(31, 119, 180)', + opacity = 0.5 +)) +fig2.add_trace(go.Bar( + x=hicos_mc_after_full_stats_eval.index, + y=hicos_mc_after_full_stats_eval.values, + name='After FULL', + marker_color='rgb(31, 119, 180)', + opacity=1 +)) +#make log scale +fig2.update_layout( + title='Progression of the Number of HICOs per Major Class', + xaxis_tickfont_size=14, + yaxis=dict( + title='Number of HICOs', + titlefont_size=16, + tickfont_size=14, + 
), + xaxis=dict( + title='Major Class', + titlefont_size=16, + tickfont_size=14, + ), + legend=dict( + x=0.8, + y=1.0, + bgcolor='rgba(255, 255, 255, 0)', + bordercolor='rgba(255, 255, 255, 0)' + ), + barmode='group', + bargap=0.15, + bargroupgap=0.1 +) +#make transparent background +fig2.update_layout(plot_bgcolor='rgba(0,0,0,0)') +#log scalew +fig2.update_yaxes(type="log") + +fig2.update_layout(legend=dict( + yanchor="top", + y=0.99, + xanchor="left", + x=0.01 +)) +#tilt the x axis labels +fig2.update_layout(xaxis_tickangle=22.5) +#set the range of the y axis +fig2.update_yaxes(range=[0, 4.5]) +#adjust bargap +fig2.update_layout(bargap=0.3) +fig2.update_layout(legend=dict( + orientation="h", + yanchor="bottom", + y=1.02, + xanchor="right", + x=1 +)) +#fig2.write_image("progression_hicos_per_mc_eval.svg") +fig2.show() +# %% +#append df and df_after_id +df_all_hico = df.append(id_df) +loco_hico_na_stats_after = {} +loco_hico_na_stats_after['HICO from NA'] = sum(df_all_hico.index.isin(before_na_seqs))/tcga_df.shape[0] +loco_pred_df = df_all_hico[df_all_hico.index.isin(before_loco_seqs)] +loco_anns_pd = tcga_df.loc[loco_pred_df.index].subclass_name.str.split(';',expand=True) +loco_anns_pd = loco_anns_pd.apply(lambda x: x.str.lower()) +#duplicate labels in loco_pred_df * times as the num of columns in loco_anns_pd +loco_pred_labels_df = pd.DataFrame(np.repeat(loco_pred_df['Net-Label'].values,loco_anns_pd.shape[1]).reshape(loco_pred_df.shape[0],loco_anns_pd.shape[1])).set_index(loco_pred_df.index) +loco_pred_labels_df = loco_pred_labels_df.apply(lambda x: x.str.lower()) + + + +#%% +trna_mask_df = loco_pred_labels_df.apply(lambda x: x.str.contains('_trna')).any(axis=1) +trna_loco_pred_df = loco_pred_labels_df[trna_mask_df] +#get trna_loco_anns_pd +trna_loco_anns_pd = loco_anns_pd[trna_mask_df] +#for trna_loco_pred_df, remove what prepends the __ and what appends the last - +trna_loco_pred_df = trna_loco_pred_df.apply(lambda x: x.str.split('__').str[1]) +trna_loco_pred_df = trna_loco_pred_df.apply(lambda x: x.str.split('-').str[:-1].str.join('-')) +#compute overlap between trna_loco_pred_df and trna_loco_anns_pd +#for every value in trna_loco_pred_df, check if is part of the corresponding position in trna_loco_anns_pd +num_hico_trna_from_loco = 0 +for idx,row in trna_loco_pred_df.iterrows(): + trna_label = row[0] + num_hico_trna_from_loco += trna_loco_anns_pd.loc[idx].apply(lambda x: x!=None and trna_label in x).any() + + +#%% +#check if 'mir' or 'let' is in any of the values per row. 
the columns are numbered from 0 to len(loco_anns_pd.columns) +mir_mask_df = loco_pred_labels_df.apply(lambda x: x.str.contains('mir')).any(axis=1) +let_mask_df = loco_pred_labels_df.apply(lambda x: x.str.contains('let')).any(axis=1) +mir_or_let_mask_df = mir_mask_df | let_mask_df +mir_or_let_loco_pred_df = loco_pred_labels_df[mir_or_let_mask_df] +mir_or_let_loco_anns_pd = loco_anns_pd[mir_or_let_mask_df] +#for each value in mir_or_let_loco_pred_df, if the value contains two '-', remove the last one and what comes after it +mir_or_let_loco_anns_pd = mir_or_let_loco_anns_pd.applymap(lambda x: '-'.join(x.split('-')[:-1]) if x!=None and x.count('-') == 2 else x) +mir_or_let_loco_pred_df = mir_or_let_loco_pred_df.applymap(lambda x: '-'.join(x.split('-')[:-1]) if x!=None and x.count('-') == 2 else x) +#compute overlap between mir_or_let_loco_pred_df and mir_or_let_loco_anns_pd +num_hico_mir_from_loco = sum((mir_or_let_loco_anns_pd == mir_or_let_loco_pred_df).any(axis=1)) +#%% + + +#get rest_loco_anns_pd +rest_loco_pred_df = loco_pred_labels_df[~mir_or_let_mask_df & ~trna_mask_df] +rest_loco_anns_pd = loco_anns_pd[~mir_or_let_mask_df & ~trna_mask_df] + +num_hico_bins_from_loco = 0 +for idx,row in rest_loco_pred_df.iterrows(): + rest_rna_label = row[0].split('-')[0] + try: + bin_no = int(row[0].split('-')[1]) + except: + continue + + num_hico_bins_from_loco += rest_loco_anns_pd.loc[idx].apply(lambda x: x!=None and rest_rna_label == x.split('-')[0] and abs(int(x.split('-')[1])- bin_no)<=1).any() + +loco_hico_na_stats_after['HICO from LOCO'] = (num_hico_trna_from_loco + num_hico_mir_from_loco + num_hico_bins_from_loco)/tcga_df.shape[0] +loco_hico_na_stats_after['LOCO from NA'] = loco_hico_na_stats_before['NA'] - loco_hico_na_stats_after['HICO from NA'] +loco_hico_na_stats_after['LOCO from LOCO'] = loco_hico_na_stats_before['LOCO'] - loco_hico_na_stats_after['HICO from LOCO'] +loco_hico_na_stats_after['HICO'] = loco_hico_na_stats_before['HICO'] + +# %% + +import plotly.graph_objects as go +import plotly.io as pio +import plotly.express as px + +color_mapping = {} +for key in loco_hico_na_stats_before.keys(): + if key.startswith('HICO'): + color_mapping[key] = "rgb(51,160,44)" + elif key.startswith('LOCO'): + color_mapping[key] = "rgb(178,223,138)" + else: + color_mapping[key] = "rgb(251,154,153)" +colors = list(color_mapping.values()) +fig = go.Figure(data=[go.Pie(labels=list(loco_hico_na_stats_before.keys()), values=list(loco_hico_na_stats_before.values()),hole=.0,marker=dict(colors=colors),sort=False)]) +fig.update_layout(title='Percentage of HICOs, LOCOs and NAs before ID') +fig.show() +#save figure as svg +#fig.write_image("pie_chart_before_id.svg") + +# %% + +color_mapping = {} +for key in loco_hico_na_stats_after.keys(): + if key.startswith('HICO'): + color_mapping[key] = "rgb(51,160,44)" + elif key.startswith('LOCO'): + color_mapping[key] = "rgb(178,223,138)" + +loco_hico_na_stats_after = {k: loco_hico_na_stats_after[k] for k in sorted(loco_hico_na_stats_after, key=lambda k: k.startswith('HICO'), reverse=True)} + +fig = go.Figure(data=[go.Pie(labels=list(loco_hico_na_stats_after.keys()), values=list(loco_hico_na_stats_after.values()),hole=.0,marker=dict(colors=list(color_mapping.values())),sort=False)]) +fig.update_layout(title='Percentage of HICOs, LOCOs and NAs after ID') +fig.show() +#save figure as svg +#fig.write_image("pie_chart_after_id.svg") diff --git a/transforna/bin/figure_scripts/figure_6.ipynb b/transforna/bin/figure_scripts/figure_6.ipynb new file mode 100644 index 
0000000000000000000000000000000000000000..70460f8402e76a33c3022f7e75d4bc85979a80d8 --- /dev/null +++ b/transforna/bin/figure_scripts/figure_6.ipynb @@ -0,0 +1,228 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/nfs/home/yat_ldap/conda/envs/hbdx/envs/transforna/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from transforna import load,predict_transforna_all_models,predict_transforna,fold_sequences\n", + "models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'\n", + "lc_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'\n", + "tcga_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv'\n", + "\n", + "tcga_df = load(tcga_path)\n", + "lc_df = load(lc_path)\n", + "\n", + "lc_df = lc_df[lc_df.sequence.str.len() <= 30]\n", + "\n", + "all_seqs = lc_df.sequence.tolist()+tcga_df.sequence.tolist()\n", + "\n", + "mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json'\n", + "mapping_dict = load(mapping_dict_path)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictions = predict_transforna_all_models(all_seqs,trained_on='full',path_to_models=models_path)\n", + "predictions.to_csv('predictions_lc_tcga.csv',index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "#read predictions\n", + "predictions = load('predictions_lc_tcga.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "umaps = {}\n", + "models = predictions['Model'].unique()\n", + "for model in models:\n", + " if model == 'Ensemble':\n", + " continue\n", + " #get predictions\n", + " model_predictions = predictions[predictions['Model']==model]\n", + " #get is familiar rows\n", + " familiar_df = model_predictions[model_predictions['Is Familiar?']==True]\n", + " #get umap\n", + " umap_df = predict_transforna(model_predictions['Sequence'].tolist(),model=model,trained_on='full',path_to_models=models_path,umap_flag=True)\n", + " umaps[model] = umap_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import plotly.express as px\n", + "import numpy as np\n", + "mcs = np.unique(umaps['Seq']['Net-Label'].map(mapping_dict))\n", + "#filter out the classes that contain ;\n", + "mcs = [mc for mc in mcs if ';' not in mc]\n", + "colors = px.colors.qualitative.Plotly\n", + "color_mapping = dict(zip(mcs,colors))\n", + "for model,umap_df in umaps.items():\n", + " umap_df['Major Class'] = umap_df['Net-Label'].map(mapping_dict)\n", + " umap_df_copy = umap_df.copy()\n", + " #remove rows with Major Class containing ;\n", + " umap_df = umap_df[~umap_df['Major Class'].str.contains(';')]\n", + " fig = px.scatter(umap_df,x='UMAP1',y='UMAP2',color='Major Class',hover_data\n", + " =['Sequence'],title=model,\\\n", + " width = 800, 
height=800,color_discrete_map=color_mapping)\n", + " fig.update_traces(marker=dict(size=1))\n", + " #white background\n", + " fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n", + " #only show UMAP1 from 4.3 to 11\n", + " fig.update_xaxes(range=[4.3,11])\n", + " #and UMAP2 from -2.3 to 6.8\n", + " fig.update_yaxes(range=[-2.3,6.8])\n", + " #fig.show()\n", + " fig.write_image(f'lc_figures/lc_tcga_umap_selected_{model}.png')\n", + " fig.write_image(f'lc_figures/lc_tcga_umap_selected_{model}.svg')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.express as px\n", + "import numpy as np\n", + "mcs = np.unique(umaps['Seq']['Net-Label'].map(mapping_dict))\n", + "#filter out the classes that contain ;\n", + "mcs = [mc for mc in mcs if ';' not in mc]\n", + "colors = px.colors.qualitative.Plotly + px.colors.qualitative.Light24\n", + "color_mapping = dict(zip(mcs,colors))\n", + "for model,umap_df in umaps.items():\n", + " umap_df['Major Class'] = umap_df['Net-Label'].map(mapping_dict)\n", + " umap_df_copy = umap_df.copy()\n", + " #remove rows with Major Class containing ;\n", + " umap_df = umap_df[~umap_df['Major Class'].str.contains(';')]\n", + " fig = px.scatter(umap_df,x='UMAP1',y='UMAP2',color='Major Class',hover_data\n", + " =['Sequence'],title=model,\\\n", + " width = 800, height=800,color_discrete_map=color_mapping)\n", + " fig.update_traces(marker=dict(size=1))\n", + " #white background\n", + " fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n", + " #fig.show()\n", + " fig.write_image(f'lc_figures/lc_tcga_umap_{model}.png')\n", + " fig.write_image(f'lc_figures/lc_tcga_umap_{model}.svg')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#plot umap using px.scatter for each model\n", + "import plotly.express as px\n", + "import numpy as np\n", + "mcs = np.unique(umaps['Seq']['Net-Label'].map(mapping_dict))\n", + "#filter out the classes that contain ;\n", + "mcs = [mc for mc in mcs if ';' not in mc]\n", + "colors = px.colors.qualitative.Plotly\n", + "color_mapping = dict(zip(mcs,colors))\n", + "umap_df = umaps['Seq']\n", + "umap_df['Major Class'] = umap_df['Net-Label'].map(mapping_dict)\n", + "umap_df_copy = umap_df.copy()\n", + "#display points contained within the circle at center (7.9,2.5) and radius 4.3\n", + "umap_df_copy['distance'] = np.sqrt((umap_df_copy['UMAP1']-7.9)**2+(umap_df_copy['UMAP2']-2.5)**2)\n", + "umap_df_copy = umap_df_copy[umap_df_copy['distance']<=4.3]\n", + "#remove rows with Major Class containing ;\n", + "umap_df_copy = umap_df_copy[~umap_df_copy['Major Class'].str.contains(';')]\n", + "fig = px.scatter(umap_df_copy,x='UMAP1',y='UMAP2',color='Major Class',hover_data\n", + " =['Sequence'],title=model,\\\n", + " width = 800, height=800,color_discrete_map=color_mapping)\n", + "fig.update_traces(marker=dict(size=1))\n", + "#white background\n", + "fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n", + "fig.show()\n", + "#fig.write_image(f'lc_figures/lc_tcga_umap_selected_{model}.png')\n", + "#fig.write_image(f'lc_figures/lc_tcga_umap_selected_{model}.svg')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#plot\n", + "sec_struct = fold_sequences(model_predictions['Sequence'].tolist())['structure_37']\n", + "#sec struct ratio is calculated as the number of non '.' 
characters divided by the length of the sequence\n", + "sec_struct_ratio = sec_struct.apply(lambda x: (len(x)-x.count('.'))/len(x))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "umap_df = umaps['Seq-Struct']\n", + "fig = px.scatter(umap_df,x='UMAP1',y='UMAP2',color=sec_struct_ratio,hover_data=['Sequence'],title=model,\\\n", + " width = 800, height=800,color_continuous_scale='Viridis')\n", + "fig.update_traces(marker=dict(size=1))\n", + "fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n", + "#save\n", + "fig.write_image(f'lc_figures/lc_tcga_umap_{model}_dot_bracket.png')\n", + "fig.write_image(f'lc_figures/lc_tcga_umap_{model}_dot_bracket.svg')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "transforna", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/transforna/bin/figure_scripts/figure_S4.ipynb b/transforna/bin/figure_scripts/figure_S4.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..11b8ba8e0293ba6067b4b6fe10fb99b2c9dfce20 --- /dev/null +++ b/transforna/bin/figure_scripts/figure_S4.ipynb @@ -0,0 +1,368 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import pandas as pd\n", + "scores = {'major_class':{},'sub_class':{}}\n", + "models = ['Baseline','Seq','Seq-Seq','Seq-Struct','Seq-Rev']\n", + "models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_ID'\n", + "for model1 in models:\n", + " summary_pd = pd.read_csv(models_path+'/major_class/'+model1+'/summary_pd.tsv',sep='\\t')\n", + " scores['major_class'][model1] = str(summary_pd['B. Acc'].mean()*100)+'+/-'+' ('+str(summary_pd['B. Acc'].std()*100)+')'\n", + " summary_pd = pd.read_csv(models_path+'/sub_class/'+model1+'/summary_pd.tsv',sep='\\t')\n", + " scores['sub_class'][model1] = str(summary_pd['B. Acc'].mean()*100)+'+/-'+' ('+str(summary_pd['B. 
Acc'].std()*100) +')'" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'Baseline': '52.83789870060305+/- (1.0961119898709506)',\n", + " 'Seq': '97.70018230805728+/- (0.3819207447704567)',\n", + " 'Seq-Seq': '95.65091330992355+/- (0.4963151975035616)',\n", + " 'Seq-Struct': '97.71071590680333+/- (0.6173598637101496)',\n", + " 'Seq-Rev': '97.51224133899979+/- (0.3418133671042992)'}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scores['sub_class']" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import json\n", + "import pandas as pd\n", + "with open('/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json') as f:\n", + " mapping_dict = json.load(f)\n", + "\n", + "b_acc_sc_to_mc = {}\n", + "for model1 in models:\n", + " b_acc = []\n", + " for idx in range(5):\n", + " confusion_matrix = pd.read_csv(models_path+'/sub_class/'+model1+f'/embedds/confusion_matrix_{idx}.csv',sep=',',index_col=0)\n", + " confusion_matrix.index = confusion_matrix.index.map(mapping_dict)\n", + " confusion_matrix.columns = confusion_matrix.columns.map(mapping_dict)\n", + " confusion_matrix = confusion_matrix.groupby(confusion_matrix.index).sum().groupby(confusion_matrix.columns,axis=1).sum()\n", + " b_acc.append(confusion_matrix.values.diagonal().sum()/confusion_matrix.values.sum())\n", + " b_acc_sc_to_mc[model1] = str(pd.Series(b_acc).mean()*100)+'+/-'+' ('+str(pd.Series(b_acc).std()*100)+')'\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'Baseline': '89.6182558114013+/- (0.6372156071358975)',\n", + " 'Seq': '99.66714304286457+/- (0.1404591049684126)',\n", + " 'Seq-Seq': '99.40702944026852+/- (0.18268320317601783)',\n", + " 'Seq-Struct': '99.77114728744993+/- (0.06976258667467564)',\n", + " 'Seq-Rev': '99.70878801385821+/- (0.11954774341354062)'}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "b_acc_sc_to_mc" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import plotly.express as px\n", + "no_annotation_predictions = {}\n", + "for model1 in models:\n", + " #multiindex\n", + " no_annotation_predictions[model1] = pd.read_csv(models_path+'/sub_class/'+model1+'/embedds/no_annotation_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n", + " no_annotation_predictions[model1].set_index([('RNA Sequences','0')] ,inplace=True)\n", + " no_annotation_predictions[model1].index.name = 'RNA Sequences'\n", + " no_annotation_predictions[model1] = no_annotation_predictions[model1]['Logits'].idxmax(axis=1)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transforna.src.utils.tcga_post_analysis_utils import correct_labels\n", + "import pandas as pd\n", + "correlation = pd.DataFrame(index=models,columns=models)\n", + "for model1 in models:\n", + " for model2 in models:\n", + " model1_predictions = correct_labels(no_annotation_predictions[model1],no_annotation_predictions[model2],mapping_dict)\n", + " is_equal = model1_predictions == no_annotation_predictions[model2].values\n", + " correlation.loc[model1,model2] = is_equal.sum()/len(is_equal)\n", + "font_size = 20\n", + "fig = px.imshow(correlation, 
color_continuous_scale='Blues')\n", + "#annotate\n", + "for i in range(len(models)):\n", + " for j in range(len(models)):\n", + " if i != j:\n", + " font = dict(color='black', size=font_size)\n", + " else:\n", + " font = dict(color='white', size=font_size) \n", + " \n", + " fig.add_annotation(\n", + " x=j, y=i,\n", + " text=str(round(correlation.iloc[i,j],2)),\n", + " showarrow=False,\n", + " font=font\n", + " )\n", + "\n", + "#set figure size: width and height\n", + "fig.update_layout(width=800, height=800)\n", + "\n", + "fig.update_layout(title='Correlation between models for each sub_class model')\n", + "#set x and y axis to Models\n", + "fig.update_xaxes(title_text='Models', tickfont=dict(size=font_size))\n", + "fig.update_yaxes(title_text='Models', tickfont=dict(size=font_size))\n", + "fig.show()\n", + "#save\n", + "fig.write_image('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/figures/correlation_id_models_sub_class.png')\n", + "fig.write_image('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/figures/correlation_id_models_sub_class.svg')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "#create umap for every model from embedds folder\n", + "models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_ID'\n", + "\n", + "#read\n", + "sc_embedds = {}\n", + "mc_embedds = {}\n", + "sc_to_mc_labels = {}\n", + "sc_labels = {}\n", + "mc_labels = {}\n", + "for model in models:\n", + " df = pd.read_csv(models_path+'/sub_class/'+model+'/embedds/train_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n", + " sc_embedds[model] = df['RNA Embedds'].values\n", + " sc_labels[model] = df['Labels']['0']\n", + " sc_to_mc_labels[model] = sc_labels[model].map(mapping_dict).values\n", + "\n", + " #major class\n", + " df = pd.read_csv(models_path+'/major_class/'+model+'/embedds/train_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n", + " mc_embedds[model] = df['RNA Embedds'].values\n", + " mc_labels[model] = df['Labels']['0']" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "import umap\n", + "#compute umap coordinates\n", + "sc_umap_coords = {}\n", + "mc_umap_coords = {}\n", + "for model in models:\n", + " sc_umap_coords[model] = umap.UMAP(n_neighbors=5, min_dist=0.3, n_components=2, metric='euclidean').fit_transform(sc_embedds[model])\n", + " mc_umap_coords[model] = umap.UMAP(n_neighbors=5, min_dist=0.3, n_components=2, metric='euclidean').fit_transform(mc_embedds[model])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#plot umap\n", + "import plotly.express as px\n", + "import numpy as np\n", + "\n", + "mcs = np.unique(sc_to_mc_labels[models[0]])\n", + "colors = px.colors.qualitative.Plotly\n", + "color_mapping = dict(zip(mcs,colors))\n", + "for model in models:\n", + " fig = px.scatter(x=sc_umap_coords[model][:,0],y=sc_umap_coords[model][:,1],color=sc_to_mc_labels[model],labels={'color':'Major Class'},title=model, width=800, height=800,\\\n", + "\n", + " hover_data={'Major Class':sc_labels[model],'Sub Class':sc_to_mc_labels[model]},color_discrete_map=color_mapping)\n", + "\n", + " fig.update_traces(marker=dict(size=1))\n", + " #white background\n", + " fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n", + "\n", + " fig.write_image(models_path+'/sub_class/'+model+'/figures/sc_umap.svg')\n", + " 
fig.write_image(models_path+'/sub_class/'+model+'/figures/sc_umap.png')\n", + " fig.show()\n", + "\n", + " #plot umap for major class\n", + " fig = px.scatter(x=mc_umap_coords[model][:,0],y=mc_umap_coords[model][:,1],color=mc_labels[model],labels={'color':'Major Class'},title=model, width=800, height=800,\\\n", + "\n", + " hover_data={'Major Class':mc_labels[model]},color_discrete_map=color_mapping)\n", + " fig.update_traces(marker=dict(size=1))\n", + " #white background\n", + " fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n", + "\n", + " fig.write_image(models_path+'/major_class/'+model+'/figures/mc_umap.svg')\n", + " fig.write_image(models_path+'/major_class/'+model+'/figures/mc_umap.png')\n", + " fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transforna import fold_sequences\n", + "df = pd.read_csv(models_path+'/major_class/Seq-Struct/embedds/train_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n", + "sec_struct = fold_sequences(df['RNA Sequences']['0'])['structure_37']\n", + "#sec struct ratio is calculated as the number of non '.' characters divided by the length of the sequence\n", + "sec_struct_ratio = sec_struct.apply(lambda x: (len(x)-x.count('.'))/len(x))\n", + "fig = px.scatter(x=mc_umap_coords['Seq-Struct'][:,0],y=mc_umap_coords['Seq-Struct'][:,1],color=sec_struct_ratio,labels={'color':'Base Pairing'},title='Seq-Struct', width=800, height=800,\\\n", + " hover_data={'Major Class':mc_labels['Seq-Struct']}, color_continuous_scale='Viridis',range_color=[0,1])\n", + "\n", + "fig.update_traces(marker=dict(size=3))\n", + "#white background\n", + "fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n", + "fig.show()\n", + "fig.write_image(models_path+'/major_class/Seq-Struct/figures/mc_umap_sec_struct.svg')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "from transforna import fold_sequences\n", + "df = pd.read_csv(models_path+'/sub_class/Seq-Struct/embedds/train_embedds.tsv',sep='\\t',header=[0,1],index_col=[0])\n", + "sec_struct = fold_sequences(df['RNA Sequences']['0'])['structure_37']\n", + "#sec struct ratio is calculated as the number of non '.' 
characters divided by the length of the sequence\n", + "sec_struct_ratio = sec_struct.apply(lambda x: (len(x)-x.count('.'))/len(x))\n", + "fig = px.scatter(x=sc_umap_coords['Seq-Struct'][:,0],y=sc_umap_coords['Seq-Struct'][:,1],color=sec_struct_ratio,labels={'color':'Base Pairing'},title='Seq-Struct', width=800, height=800,\\\n", + " hover_data={'Major Class':mc_labels['Seq-Struct']}, color_continuous_scale='Viridis',range_color=[0,1])\n", + "\n", + "fig.update_traces(marker=dict(size=3))\n", + "#white background\n", + "fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')\n", + "fig.show()\n", + "fig.write_image(models_path+'/sub_class/Seq-Struct/figures/sc_umap_sec_struct.svg')" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "from transforna import Results_Handler,get_closest_ngbr_per_split\n", + "\n", + "splits = ['train','valid','test','ood','artificial','no_annotation']\n", + "splits_to_plot = ['test','ood','random','recombined','artificial_affix']\n", + "renaming_dict= {'test':'ID (test)','ood':'Rare sub-classes','random':'Random','artificial_affix':'Putative 5\\'-adapter prefixes','recombined':'Recombined'}\n", + "\n", + "lev_dist_df = pd.DataFrame()\n", + "for model in models:\n", + " results = Results_Handler(models_path+f'/sub_class/{model}/embedds',splits=splits,read_dataset=True)\n", + " results.append_loco_variants()\n", + " results.get_knn_model()\n", + " \n", + " #compute levenstein distance per split\n", + " for split in splits_to_plot:\n", + " split_seqs,split_labels,top_n_seqs,top_n_labels,distances,lev_dist = get_closest_ngbr_per_split(results,split)\n", + " #create df from split and levenstein distance\n", + " lev_dist_split_df = pd.DataFrame({'split':split,'lev_dist':lev_dist,'seqs':split_seqs,'labels':split_labels,'top_n_seqs':top_n_seqs,'top_n_labels':top_n_labels})\n", + " #rename \n", + " lev_dist_split_df['split'] = lev_dist_split_df['split'].map(renaming_dict)\n", + " lev_dist_split_df['model'] = model\n", + " #append \n", + " lev_dist_df = pd.concat([lev_dist_df,lev_dist_split_df],axis=0)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#plot the distribution of lev_dist for each split for each model\n", + "model_thresholds = {'Baseline':0.267,'Seq':0.246,'Seq-Seq':0.272,'Seq-Struct': 0.242,'Seq-Rev':0.237}\n", + "model_aucs = {'Baseline':0.76,'Seq':0.97,'Seq-Seq':0.96,'Seq-Struct': 0.97,'Seq-Rev':0.97}\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "sns.set_theme(style=\"whitegrid\")\n", + "sns.set(rc={'figure.figsize':(15,10)})\n", + "sns.set(font_scale=1.5)\n", + "ax = sns.boxplot(x=\"model\", y=\"lev_dist\", hue=\"split\", data=lev_dist_df, palette=\"Set3\",order=models,showfliers = True)\n", + "#add title\n", + "ax.set_facecolor('None')\n", + "plt.title('Levenshtein Distance Distribution per Model on ID')\n", + "ax.set(xlabel='Model', ylabel='Normalized Levenshtein Distance')\n", + "#legend background should transparent\n", + "ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.,facecolor=None,framealpha=0.0)\n", + "# add horizontal lines for thresholds for each model while making sure the line is within the boxplot\n", + "min_val = 0 \n", + "for model in models:\n", + " thresh = model_thresholds[model]\n", + " plt.axhline(y=thresh, color='g', linestyle='--',xmin=min_val,xmax=min_val+0.2)\n", + " min_val+=0.2\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "transforna", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/transforna/bin/figure_scripts/figure_S5.py b/transforna/bin/figure_scripts/figure_S5.py new file mode 100644 index 0000000000000000000000000000000000000000..a32c49f32d6fee6202df81adef6c85e5fa81235a --- /dev/null +++ b/transforna/bin/figure_scripts/figure_S5.py @@ -0,0 +1,94 @@ +#%% +import numpy as np +import pandas as pd + +import transforna +from transforna import IDModelAugmenter, load + +#%% +model_name = 'Seq' +config_path = f'/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_ID/sub_class/{model_name}/meta/hp_settings.yaml' +config = load(config_path) +model_augmenter = IDModelAugmenter(df=None,config=config) +df = model_augmenter.predict_transforna_na() +tcga_df = load('/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv') + +tcga_df.set_index('sequence',inplace=True) +tcga_df['Labels'] = tcga_df['subclass_name'][tcga_df['hico'] == True] +tcga_df['Labels'] = tcga_df['Labels'].astype('category') +#%% +tcga_df.loc[df.Sequence.values,'Labels'] = df['Net-Label'].values + +loco_labels_df = tcga_df['subclass_name'].str.split(';',expand=True).loc[df['Sequence']] +#filter out the rows having no_annotation in the first column of loco_labels_df +loco_labels_df = loco_labels_df.iloc[~(loco_labels_df[0] == 'no_annotation').values] +#%% +#get the Is Familiar? column from df based on index of loco_labels_df +novelty_prediction_loco_df = df[df['Sequence'].isin(loco_labels_df.index)].set_index('Sequence')['Is Familiar?'] +#%% +id_predictions_df = tcga_df.loc[loco_labels_df.index]['Labels'] +#copy the column of id_predictions_df a number of times equal to the number of columns in loco_labels_df +id_predictions_df = pd.concat([id_predictions_df]*loco_labels_df.shape[1],axis=1) +id_predictions_df.columns = np.arange(loco_labels_df.shape[1]) +equ_mask = loco_labels_df == id_predictions_df +#check how many rows in equ_mask have at least one True +num_true = equ_mask.any(axis=1).sum() +print('percentage of all loco RNAs with a matching annotation: ',num_true/equ_mask.shape[0]) + + +#split loco_labels_df into two dataframes: familiar and novel +fam_loco_labels_df = loco_labels_df[novelty_prediction_loco_df] +novel_loco_labels__df = loco_labels_df[~novelty_prediction_loco_df] +#separate id_predictions_df into two dataframes: 
familiar and novel +id_predictions_fam_df = id_predictions_df[novelty_prediction_loco_df] +id_predictions_novel_df = id_predictions_df[~novelty_prediction_loco_df] +#%% +num_true_fam = (fam_loco_labels_df == id_predictions_fam_df).any(axis=1).sum() +num_true_novel = (novel_loco_labels__df == id_predictions_novel_df).any(axis=1).sum() + +print('percentage of similar predictions in familiar: ',num_true_fam/fam_loco_labels_df.shape[0]) +print('percentage of similar predictions in novel: ',num_true_novel/novel_loco_labels__df.shape[0]) +print('') +# %% +#remove the rows in fam_loco_labels_df and id_predictions_fam_df that have at least one True in equ_mask +fam_loco_labels_no_overlap_df = fam_loco_labels_df[~equ_mask.any(axis=1)] +id_predictions_fam_no_overlap_df = id_predictions_fam_df[~equ_mask.any(axis=1)] +#collapse the dataframe of fam_loco_labels_df with a ';' separator +collapsed_loco_labels_df = fam_loco_labels_no_overlap_df.apply(lambda x: ';'.join(x.dropna().astype(str)),axis=1) +#combine collapsed_loco_labels_df with id_predictions_fam_df[0] +predicted_fam_but_ann_novel_df = pd.concat([collapsed_loco_labels_df,id_predictions_fam_no_overlap_df[0]],axis=1) +#rename columns +predicted_fam_but_ann_novel_df.columns = ['KBA_labels','predicted_label'] +# %% +#get major class for each column in KBA_labels and predicted_label +mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/subclass_to_annotation.json' +sc_to_mc_mapper_dict = load(mapping_dict_path) + +predicted_fam_but_ann_novel_df['KBA_labels_mc'] = predicted_fam_but_ann_novel_df['KBA_labels'].str.split(';').apply(lambda x: ';'.join([sc_to_mc_mapper_dict[i] if i in sc_to_mc_mapper_dict.keys() else i for i in x])) +predicted_fam_but_ann_novel_df['predicted_label_mc'] = predicted_fam_but_ann_novel_df['predicted_label'].apply(lambda x: sc_to_mc_mapper_dict[x] if x in sc_to_mc_mapper_dict.keys() else x) +# %% +#for each sequence in predicted_fam_but_ann_novel_df, compute the most similar sequence along with the Levenshtein distance +from transforna import predict_transforna + +sim_df = predict_transforna(model=model_name,sequences=predicted_fam_but_ann_novel_df.index.tolist(),similarity_flag=True,n_sim=1,trained_on='id',path_to_models='/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/') +sim_df = sim_df.set_index('Sequence') + +#append the sim_df to predicted_fam_but_ann_novel_df except for the Labels column +predicted_fam_but_ann_novel_df = pd.concat([predicted_fam_but_ann_novel_df,sim_df.drop('Labels',axis=1)],axis=1) +# %% +#plot the mc proportions of predicted_label_mc +predicted_fam_but_ann_novel_df['predicted_label_mc'].value_counts().plot(kind='bar') +#get order of labels on x axis +x_labels = predicted_fam_but_ann_novel_df['predicted_label_mc'].value_counts().index.tolist() +# %% +#plot the LV distance per predicted_label_mc and order the x axis based on the order of x_labels +fig = predicted_fam_but_ann_novel_df.boxplot(column='NLD',by='predicted_label_mc',figsize=(20,10),rot=90,showfliers=False) +#reorder x axis in fig by x_labels +fig.set_xticklabels(x_labels) +#increase font of axis labels and ticks +fig.set_xlabel('Predicted Label',fontsize=20) +fig.set_ylabel('Levenshtein Distance',fontsize=20) +fig.tick_params(axis='both', which='major', labelsize=20) +#display pandas full rows +pd.set_option('display.max_rows', None) +# %% diff --git a/transforna/bin/figure_scripts/figure_S8.py b/transforna/bin/figure_scripts/figure_S8.py new file mode 100644 index 
0000000000000000000000000000000000000000..efcffd322de7abbdc20afea20c271801eaf799ec --- /dev/null +++ b/transforna/bin/figure_scripts/figure_S8.py @@ -0,0 +1,56 @@ +#%% +from transforna import load +from transforna import Results_Handler +import pandas as pd +import plotly.graph_objects as go +import plotly.io as pio +path = '/media/ftp_share/hbdx/analysis/tcga/TransfoRNA_I_ID_V4/sub_class/Seq/embedds/' +results:Results_Handler = Results_Handler(path=path,splits=['train','valid','test','ood'],read_ad=True) + +mapping_dict = load('/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v02/subclass_to_annotation.json') +mapping_dict['artificial_affix'] = 'artificial_affix' +train_df = results.splits_df_dict['train_df'] +valid_df = results.splits_df_dict['valid_df'] +test_df = results.splits_df_dict['test_df'] +ood_df = results.splits_df_dict['ood_df'] +#remove RNA Sequences from the dataframe if not in results.ad.var.index +train_df = train_df[train_df['RNA Sequences'].isin(results.ad.var[results.ad.var['hico'] == True].index)['0'].values] +valid_df = valid_df[valid_df['RNA Sequences'].isin(results.ad.var.index[results.ad.var['hico'] == True])['0'].values] +test_df = test_df[test_df['RNA Sequences'].isin(results.ad.var.index[results.ad.var['hico'] == True])['0'].values] +ood_df = ood_df[ood_df['RNA Sequences'].isin(results.ad.var.index[results.ad.var['hico'] == True])['0'].values] +#concatenate train,valid and test +train_val_test_df = train_df.append(valid_df).append(test_df) +#map Labels to annotation +hico_id_labels = train_val_test_df['Labels','0'].map(mapping_dict).values +hico_ood_labels = ood_df['Labels','0'].map(mapping_dict).values +#read ad +ad = results.ad + +hico_loco_df = pd.DataFrame(columns=['mc','hico_id','hico_ood','loco']) +for mc in ad.var['small_RNA_class_annotation'][ad.var['hico'] == True].unique(): + hico_loco_df = hico_loco_df.append({'mc':mc, + 'hico_id':sum([mc in i for i in hico_id_labels]), + 'hico_ood':sum([mc in i for i in hico_ood_labels]), + 'loco':sum([mc in i for i in ad.var['small_RNA_class_annotation'][ad.var['hico'] != True].values])},ignore_index=True) +# %% + + +order = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','miRNA','miscRNA','lncRNA','piRNA','YRNA','vtRNA'] + +fig = go.Figure() +fig.add_trace(go.Bar(x=hico_loco_df['mc'],y=hico_loco_df['hico_id'],name='HICO ID',marker_color='#00CC96')) +fig.add_trace(go.Bar(x=hico_loco_df['mc'],y=hico_loco_df['hico_ood'],name='HICO OOD',marker_color='darkcyan')) +fig.add_trace(go.Bar(x=hico_loco_df['mc'],y=hico_loco_df['loco'],name='LOCO',marker_color='#7f7f7f')) +fig.update_layout(barmode='group') +fig.update_layout(width=800,height=800) +#order the x axis +fig.update_layout(xaxis={'categoryorder':'array','categoryarray':order}) +fig.update_layout(xaxis_title='Major Class',yaxis_title='Number of Sequences') +fig.update_layout(title='Number of Sequences per Major Class in ID, OOD and LOCO') +fig.update_layout(yaxis_type="log") +#save as png +pio.write_image(fig,'hico_id_ood_loco_proportion.png') +#save as svg +pio.write_image(fig,'hico_id_ood_loco_proportion.svg') +fig.show() +# %% diff --git a/transforna/bin/figure_scripts/figure_S9_S11.py b/transforna/bin/figure_scripts/figure_S9_S11.py new file mode 100644 index 0000000000000000000000000000000000000000..556b7181aedee9474a30472ffc51d9645e848df4 --- /dev/null +++ b/transforna/bin/figure_scripts/figure_S9_S11.py @@ -0,0 +1,136 @@ +#in this file, the progression of the number of hicos per major 
class is computed per model +#this is done before ID, after FULL. +#%% +from transforna import load +from transforna import predict_transforna,predict_transforna_all_models +import pandas as pd +import plotly.graph_objects as go +import numpy as np +import plotly.io as pio +import plotly.express as px +mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/subclass_to_annotation.json' +models_path = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/' + +mapping_dict = load(mapping_dict_path) + +#%% +dataset:str = 'LC' +hico_loco_na_flag:str = 'hico' +assert hico_loco_na_flag in ['hico','loco_na'], 'hico_loco_na_flag must be either hico or loco_na' +if dataset == 'TCGA': + dataset_path_train: str = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA/data/TCGA__ngs__miRNA_log2RPM-24.04.0__var.csv' +else: + dataset_path_train: str = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv' + +prediction_single_pd = predict_transforna(['AAAAAAACCCCCTTTTTTT'],model='Seq',logits_flag = True,trained_on='id',path_to_models=models_path) +sub_classes_used_for_training = prediction_single_pd.columns.tolist() + +var = load(dataset_path_train).set_index('sequence') +#remove from var all indexes that are longer than 30 +var = var[var.index.str.len() <= 30] +hico_seqs_all = var.index[var['hico']].tolist() +hico_labels_all = var['subclass_name'][var['hico']].values + +hico_seqs_id = var.index[var.hico & var.subclass_name.isin(sub_classes_used_for_training)].tolist() +hico_labels_id = var['subclass_name'][var.hico & var.subclass_name.isin(sub_classes_used_for_training)].values + +non_hico_seqs = var['subclass_name'][var['hico'] == False].index.values +non_hico_labels = var['subclass_name'][var['hico'] == False].values + + +#filter hico labels and hico seqs to hico ID +if hico_loco_na_flag == 'loco_na': + curr_seqs = non_hico_seqs + curr_labels = non_hico_labels +elif hico_loco_na_flag == 'hico': + curr_seqs = hico_seqs_id + curr_labels = hico_labels_id + +full_df = predict_transforna_all_models(sequences=curr_seqs,path_to_models=models_path) + + +#%% +mcs = ['rRNA','tRNA','snoRNA','protein_coding','snRNA','miRNA','miscRNA','lncRNA','piRNA','YRNA','vtRNA'] +#for each mc, get the sequences of hicos in that mc and compute the number of hicos per model +num_hicos_per_mc = {} +if hico_loco_na_flag == 'hico':#this is where ground truth exists (hico id) + curr_labels_id_mc = [mapping_dict[label] for label in curr_labels] + +elif hico_loco_na_flag == 'loco_na': # this is where ground truth does not exist (LOCO/NA) + ensemble_preds = full_df[full_df.Model == 'Ensemble'].set_index('Sequence').loc[curr_seqs].reset_index() + curr_labels_id_mc = [mapping_dict[label] for label in ensemble_preds['Net-Label']] + +for mc in mcs: + #select sequences from hico_seqs that are in the major class mc + mc_seqs = [seq for seq,label in zip(curr_seqs,curr_labels_id_mc) if label == mc] + if len(mc_seqs) == 0: + num_hicos_per_mc[mc] = {model:0 for model in full_df.Model.unique()} + continue + #only keep in full_df the sequences that are in mc_seqs + mc_full_df = full_df[full_df.Sequence.isin(mc_seqs)] + curr_num_hico_per_model = mc_full_df[mc_full_df['Is Familiar?']].groupby(['Model'])['Is Familiar?'].value_counts().droplevel(1) + #remove Baseline from index + curr_num_hico_per_model = curr_num_hico_per_model.drop('Baseline') + curr_num_hico_per_model -= 
curr_num_hico_per_model.min() + num_hicos_per_mc[mc] = curr_num_hico_per_model.to_dict() +#%% +to_plot_df = pd.DataFrame(num_hicos_per_mc) +to_plot_mcs = ['rRNA','tRNA','snoRNA'] +fig = go.Figure() +#x axis should be the mcs +for model in num_hicos_per_mc['rRNA'].keys(): + fig.add_trace(go.Bar(x=mcs, y=[num_hicos_per_mc[mc][model] for mc in mcs], name=model)) + +fig.update_layout(barmode='group') +fig.update_layout(plot_bgcolor='rgba(0,0,0,0)') +fig.write_image(f"num_hicos_per_model_{dataset}_{hico_loco_na_flag}.svg") +fig.update_yaxes(type="log") +fig.show() + +#%% + +import pandas as pd +import glob +from plotly import graph_objects as go +from transforna import load,predict_transforna +all_df = pd.DataFrame() +files = glob.glob('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/bin/lc_files/LC-novel_lev_dist_df.csv') +for file in files: + df = pd.read_csv(file) + all_df = pd.concat([all_df,df]) +all_df = all_df.drop(columns=['Unnamed: 0']) +all_df.loc[all_df.split.isnull(),'split'] = 'NA' +ensemble_df = all_df[all_df.Model == 'Ensemble'] +# %% + +lc_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv' +lc_df = load(lc_path) +lc_df.set_index('sequence',inplace=True) +# %% +#filter lc_df to only include sequences that are in ensemble_df +lc_df = lc_df.loc[ensemble_df.Sequence] +actual_major_classes = lc_df['small_RNA_class_annotation'] +predicted_major_classes = ensemble_df[['Net-Label','Sequence']].set_index('Sequence').loc[lc_df.index]['Net-Label'].map(mapping_dict) +# %% +#plot a confusion matrix between actual and predicted major classes +from sklearn.metrics import confusion_matrix +import seaborn as sns +import matplotlib.pyplot as plt +import numpy as np +major_classes = list(set(list(predicted_major_classes.unique())+list(actual_major_classes.unique()))) +conf_matrix = confusion_matrix(actual_major_classes,predicted_major_classes,labels=major_classes) +#normalize each row (actual class) by its total count +conf_matrix = conf_matrix/np.sum(conf_matrix,axis=1,keepdims=True) +#round before plotting so the heatmap annotations show the rounded proportions +for i in range(conf_matrix.shape[0]): + for j in range(conf_matrix.shape[1]): + conf_matrix[i,j] = round(conf_matrix[i,j],1) + +sns.heatmap(conf_matrix,annot=True,cmap='Blues') + +plt.xlabel('Predicted Major Class') +plt.ylabel('Actual Major Class') +plt.xticks(np.arange(len(major_classes)),major_classes,rotation=90) +plt.yticks(np.arange(len(major_classes)),major_classes,rotation=0) +plt.show() + +# %% diff --git a/transforna/bin/figure_scripts/infer_lc_using_tcga.py b/transforna/bin/figure_scripts/infer_lc_using_tcga.py new file mode 100644 index 0000000000000000000000000000000000000000..c1ac5405e30d66b06d2b8aad1085d987262e1584 --- /dev/null +++ b/transforna/bin/figure_scripts/infer_lc_using_tcga.py @@ -0,0 +1,438 @@ +import random +import sys +from random import randint + +import pandas as pd +import plotly.graph_objects as go +from anndata import AnnData + +#add parent directory to path +sys.path.append('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/') +from src import (Results_Handler, correct_labels, load, predict_transforna, + predict_transforna_all_models,get_fused_seqs) + + +def get_mc_sc(infer_df,sequences,sub_classes_used_for_training,sc_to_mc_mapper_dict,ood_flag = False): + + infered_seqs = infer_df.loc[sequences] + sc_classes_df = infered_seqs['subclass_name'].str.split(';',expand=True) + #filter rows with all nans in sc_classes_df + sc_classes_df = 
sc_classes_df[~sc_classes_df.isnull().all(axis=1)] + #cmask for classes used for training + if ood_flag: + sub_classes_used_for_training_plus_neighbors = [] + #for every subclass in sub_classes_used_for_training that contains bin, get previous and succeeding bins + for sub_class in sub_classes_used_for_training: + sub_classes_used_for_training_plus_neighbors.append(sub_class) + if 'bin' in sub_class: + bin_num = int(sub_class.split('_bin-')[1]) + if bin_num > 0: + sub_classes_used_for_training_plus_neighbors.append(f'{sub_class.split("_bin-")[0]}_bin-{bin_num-1}') + sub_classes_used_for_training_plus_neighbors.append(f'{sub_class.split("_bin-")[0]}_bin-{bin_num+1}') + if 'tR' in sub_class: + #seperate the first part(either 3p/5p), also ge tthe second part after __ + first_part = sub_class.split('-')[0] + second_part = sub_class.split('__')[1] + #get all classes in sc_to_mc_mapper_dict,values that contain both parts and append them to sub_classes_used_for_training_plus_neighbors + sub_classes_used_for_training_plus_neighbors += [sc for sc in sc_to_mc_mapper_dict.keys() if first_part in sc and second_part in sc] + sub_classes_used_for_training_plus_neighbors = list(set(sub_classes_used_for_training_plus_neighbors)) + mask = sc_classes_df.applymap(lambda x: True if (x not in sub_classes_used_for_training_plus_neighbors and 'hypermapper' not in x)\ + or pd.isnull(x) else False) + + else: + mask = sc_classes_df.applymap(lambda x: True if x in sub_classes_used_for_training or pd.isnull(x) else False) + + #check if any sub class in sub_classes_used_for_training is in sc_classes_df + if mask.apply(lambda x: all(x.tolist()), axis=1).sum() == 0: + #TODO: change to log + import logging + log_ = logging.getLogger(__name__) + log_.error('None of the sub classes used for training are in the sequences') + raise Exception('None of the sub classes used for training are in the sequences') + + #filter rows with atleast one False in mask + sc_classes_df = sc_classes_df[mask.apply(lambda x: all(x.tolist()), axis=1)] + #get mc classes + mc_classes_df = sc_classes_df.applymap(lambda x: sc_to_mc_mapper_dict[x] if x in sc_to_mc_mapper_dict else 'not_found') + #assign major class for not found values if containing 'miRNA', 'tRNA', 'rRNA', 'snRNA', 'snoRNA' + #mc_classes_df = mc_classes_df.applymap(lambda x: None if x is None else 'miRNA' if 'miR' in x else 'tRNA' if 'tRNA' in x else 'rRNA' if 'rRNA' in x else 'snRNA' if 'snRNA' in x else 'snoRNA' if 'snoRNA' in x else 'snoRNA' if 'SNO' in x else 'protein_coding' if 'RPL37A' in x else 'lncRNA' if 'SNHG1' in x else 'not_found') + #filter all 'not_found' rows + mc_classes_df = mc_classes_df[mc_classes_df.apply(lambda x: 'not_found' not in x.tolist() ,axis=1)] + #filter values with ; in mc_classes_df + mc_classes_df = mc_classes_df[~mc_classes_df[0].str.contains(';')] + #filter index + sc_classes_df = sc_classes_df.loc[mc_classes_df.index] + mc_classes_df = mc_classes_df.loc[sc_classes_df.index] + return mc_classes_df,sc_classes_df + +def plot_confusion_false_novel(df,sc_df,mc_df,save_figs:bool=False): + #filter index of sc_classes_df to contain indices of outliers df + curr_sc_classes_df = sc_df.loc[[i for i in df.index if i in sc_df.index]] + curr_mc_classes_df = mc_df.loc[[i for i in df.index if i in mc_df.index]] + #convert Labels to mc_Labels + df = df.assign(predicted_mc_labels=df.apply(lambda x: sc_to_mc_mapper_dict[x['predicted_sc_labels']] if x['predicted_sc_labels'] in sc_to_mc_mapper_dict else 'miRNA' if 'miR' in x['predicted_sc_labels'] else 'tRNA' if 
'tRNA' in x['predicted_sc_labels'] else 'rRNA' if 'rRNA' in x['predicted_sc_labels'] else 'snRNA' if 'snRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'snoRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'SNOR' in x['predicted_sc_labels'] else 'protein_coding' if 'RPL37A' in x['predicted_sc_labels'] else 'lncRNA' if 'SNHG1' in x['predicted_sc_labels'] else x['predicted_sc_labels'], axis=1)) + #add mc classes + df = df.assign(actual_mc_labels=curr_mc_classes_df[0].values.tolist()) + #add sc classes + df = df.assign(actual_sc_labels=curr_sc_classes_df[0].values.tolist()) + #compute accuracy + df = df.assign(mc_accuracy=df.apply(lambda x: 1 if x['actual_mc_labels'] == x['predicted_mc_labels'] else 0, axis=1)) + df = df.assign(sc_accuracy=df.apply(lambda x: 1 if x['actual_sc_labels'] == x['predicted_sc_labels'] else 0, axis=1)) + + #use plotly to plot confusion matrix based on mc classes + mc_confusion_matrix = df.groupby(['actual_mc_labels','predicted_mc_labels'])['mc_accuracy'].count().unstack() + mc_confusion_matrix = mc_confusion_matrix.fillna(0) + mc_confusion_matrix = mc_confusion_matrix.apply(lambda x: x/x.sum(), axis=1) + mc_confusion_matrix = mc_confusion_matrix.applymap(lambda x: round(x,2)) + #for columns not in rows, sum the column values and add them to a new column called 'other' + other_col = [0]*mc_confusion_matrix.shape[0] + for i in [i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()]: + other_col += mc_confusion_matrix[i] + mc_confusion_matrix['other'] = other_col + #add an other row with all zeros + mc_confusion_matrix.loc['other'] = [0]*mc_confusion_matrix.shape[1] + #drop all columns not in rows + mc_confusion_matrix = mc_confusion_matrix.drop([i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()], axis=1) + #plot confusion matri + fig = go.Figure(data=go.Heatmap( + z=mc_confusion_matrix.values, + x=mc_confusion_matrix.columns, + y=mc_confusion_matrix.index, + hoverongaps = False)) + #add z values to heatmap + for i in range(len(mc_confusion_matrix.index)): + for j in range(len(mc_confusion_matrix.columns)): + fig.add_annotation(text=str(mc_confusion_matrix.values[i][j]), x=mc_confusion_matrix.columns[j], y=mc_confusion_matrix.index[i], + showarrow=False, font_size=25, font_color='red') + #add title + fig.update_layout(title_text='Confusion matrix based on mc classes for false novel sequences') + #label x axis and y axis + fig.update_xaxes(title_text='Predicted mc class') + fig.update_yaxes(title_text='Actual mc class') + #save + if save_figs: + fig.write_image('transforna/bin/lc_figures/confusion_matrix_mc_classes_false_novel.png') + + +def compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers = False,fig_prefix:str = '',save_figs:bool=False): + font_size = 25 + if fig_prefix == 'LC-familiar': + font_size = 10 + #rename Labels to predicted_sc_labels + prediction_pd = prediction_pd.rename(columns={'Net-Label':'predicted_sc_labels'}) + + for model in prediction_pd['Model'].unique(): + #get model predictions + num_rows = sc_classes_df.shape[0] + model_prediction_pd = prediction_pd[prediction_pd['Model'] == model] + model_prediction_pd = model_prediction_pd.set_index('Sequence') + #filter index of model_prediction_pd to contain indices of sc_classes_df + model_prediction_pd = model_prediction_pd.loc[[i for i in model_prediction_pd.index if i in sc_classes_df.index]] + + try: #try because ensemble models do not have a folder + #check how many of the hico seqs exist in the train_df + 
embedds_path = f'/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/{model}/embedds' + results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train']) + except: + embedds_path = f'/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/Seq-Rev/embedds' + results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train']) + + train_seqs = set(results.splits_df_dict['train_df']['RNA Sequences']['0'].values.tolist()) + common_seqs = train_seqs.intersection(set(model_prediction_pd.index.tolist())) + print(f'Number of common seqs between train_df and {model} is {len(common_seqs)}') + #print(f'removing overlaping sequences between train set and inference') + #remove common_seqs from model_prediction_pd + #model_prediction_pd = model_prediction_pd.drop(common_seqs) + + + #compute number of sequences where NLD is higher than Novelty Threshold + num_outliers = sum(model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold']) + false_novel_df = model_prediction_pd[model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold']] + + plot_confusion_false_novel(false_novel_df,sc_classes_df,mc_classes_df,save_figs) + #draw a pie chart depicting number of outliers per actual_mc_labels + fig_outl = mc_classes_df.loc[false_novel_df.index][0].value_counts().plot.pie(autopct='%1.1f%%',figsize=(6, 6)) + fig_outl.set_title(f'False Novel per MC for {model}: {num_outliers}') + if save_figs: + fig_outl.get_figure().savefig(f'transforna/bin/lc_figures/false_novel_mc_{model}.png') + fig_outl.get_figure().clf() + #get number of unique sub classes per major class in false_novel_df + false_novel_sc_freq_df = sc_classes_df.loc[false_novel_df.index][0].value_counts().to_frame() + #save index as csv + #false_novel_sc_freq_df.to_csv(f'false_novel_sc_freq_df_{model}.csv') + #add mc to false_novel_sc_freq_df + false_novel_sc_freq_df['MC'] = false_novel_sc_freq_df.index.map(lambda x: sc_to_mc_mapper_dict[x]) + #plot pie chart showing unique sub classes per major class in false_novel_df + fig_outl_sc = false_novel_sc_freq_df.groupby('MC')[0].sum().plot.pie(autopct='%1.1f%%',figsize=(6, 6)) + fig_outl_sc.set_title(f'False novel: No. 
Unique sub classes per MC {model}: {num_outliers}') + if save_figs: + fig_outl_sc.get_figure().savefig(f'transforna/bin/lc_figures/{fig_prefix}_false_novel_sc_{model}.png') + fig_outl_sc.get_figure().clf() + #filter outliers + if seperate_outliers: + model_prediction_pd = model_prediction_pd[model_prediction_pd['NLD'] <= model_prediction_pd['Novelty Threshold']] + else: + #set the predictions of outliers to 'Outlier' + model_prediction_pd.loc[model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold'],'predicted_sc_labels'] = 'Outlier' + model_prediction_pd.loc[model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold'],'predicted_mc_labels'] = 'Outlier' + sc_to_mc_mapper_dict['Outlier'] = 'Outlier' + + #filter index of sc_classes_df to contain indices of model_prediction_pd + curr_sc_classes_df = sc_classes_df.loc[[i for i in model_prediction_pd.index if i in sc_classes_df.index]] + curr_mc_classes_df = mc_classes_df.loc[[i for i in model_prediction_pd.index if i in mc_classes_df.index]] + #convert Labels to mc_Labels + model_prediction_pd = model_prediction_pd.assign(predicted_mc_labels=model_prediction_pd.apply(lambda x: sc_to_mc_mapper_dict[x['predicted_sc_labels']] if x['predicted_sc_labels'] in sc_to_mc_mapper_dict else 'miRNA' if 'miR' in x['predicted_sc_labels'] else 'tRNA' if 'tRNA' in x['predicted_sc_labels'] else 'rRNA' if 'rRNA' in x['predicted_sc_labels'] else 'snRNA' if 'snRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'snoRNA' in x['predicted_sc_labels'] else 'snoRNA' if 'SNOR' in x['predicted_sc_labels'] else 'protein_coding' if 'RPL37A' in x['predicted_sc_labels'] else 'lncRNA' if 'SNHG1' in x['predicted_sc_labels'] else x['predicted_sc_labels'], axis=1)) + #add mc classes + model_prediction_pd = model_prediction_pd.assign(actual_mc_labels=curr_mc_classes_df[0].values.tolist()) + #add sc classes + model_prediction_pd = model_prediction_pd.assign(actual_sc_labels=curr_sc_classes_df[0].values.tolist()) + #correct labels + model_prediction_pd['predicted_sc_labels'] = correct_labels(model_prediction_pd['predicted_sc_labels'],model_prediction_pd['actual_sc_labels'],sc_to_mc_mapper_dict) + #compute accuracy + model_prediction_pd = model_prediction_pd.assign(mc_accuracy=model_prediction_pd.apply(lambda x: 1 if x['actual_mc_labels'] == x['predicted_mc_labels'] else 0, axis=1)) + model_prediction_pd = model_prediction_pd.assign(sc_accuracy=model_prediction_pd.apply(lambda x: 1 if x['actual_sc_labels'] == x['predicted_sc_labels'] else 0, axis=1)) + + if not seperate_outliers: + cols_to_save = ['actual_mc_labels','predicted_mc_labels','predicted_sc_labels','actual_sc_labels'] + total_false_mc_predictions_df = model_prediction_pd[model_prediction_pd.actual_mc_labels != model_prediction_pd.predicted_mc_labels].loc[:,cols_to_save] + #add a column indicating if NLD is greater than Novelty Threshold + total_false_mc_predictions_df['is_novel'] = model_prediction_pd.loc[total_false_mc_predictions_df.index]['NLD'] > model_prediction_pd.loc[total_false_mc_predictions_df.index]['Novelty Threshold'] + #save + total_false_mc_predictions_df.to_csv(f'transforna/bin/lc_files/{fig_prefix}_total_false_mcs_w_out_{model}.csv') + total_true_mc_predictions_df = model_prediction_pd[model_prediction_pd.actual_mc_labels == model_prediction_pd.predicted_mc_labels].loc[:,cols_to_save] + #add a column indicating if NLD is greater than Novelty Threshold + total_true_mc_predictions_df['is_novel'] = model_prediction_pd.loc[total_true_mc_predictions_df.index]['NLD'] > 
model_prediction_pd.loc[total_true_mc_predictions_df.index]['Novelty Threshold'] + #save + total_true_mc_predictions_df.to_csv(f'transforna/bin/lc_files/{fig_prefix}_total_true_mcs_w_out_{model}.csv') + + print('Model: ', model) + print('num_outliers: ', num_outliers) + #print accuracy including outliers + print('mc_accuracy: ', model_prediction_pd['mc_accuracy'].mean()) + print('sc_accuracy: ', model_prediction_pd['sc_accuracy'].mean()) + + #print balanced accuracy + print('mc_balanced_accuracy: ', model_prediction_pd.groupby('actual_mc_labels')['mc_accuracy'].mean().mean()) + print('sc_balanced_accuracy: ', model_prediction_pd.groupby('actual_sc_labels')['sc_accuracy'].mean().mean()) + + #use plotly to plot confusion matrix based on mc classes + mc_confusion_matrix = model_prediction_pd.groupby(['actual_mc_labels','predicted_mc_labels'])['mc_accuracy'].count().unstack() + mc_confusion_matrix = mc_confusion_matrix.fillna(0) + mc_confusion_matrix = mc_confusion_matrix.apply(lambda x: x/x.sum(), axis=1) + mc_confusion_matrix = mc_confusion_matrix.applymap(lambda x: round(x,4)) + #for columns not in rows, sum the column values and add them to a new column called 'other' + other_col = [0]*mc_confusion_matrix.shape[0] + for i in [i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()]: + other_col += mc_confusion_matrix[i] + mc_confusion_matrix['other'] = other_col + #add an other row with all zeros + mc_confusion_matrix.loc['other'] = [0]*mc_confusion_matrix.shape[1] + #drop all columns not in rows + mc_confusion_matrix = mc_confusion_matrix.drop([i for i in mc_confusion_matrix.columns if i not in mc_confusion_matrix.index.tolist()], axis=1) + #plot confusion matrix + + fig = go.Figure(data=go.Heatmap( + z=mc_confusion_matrix.values, + x=mc_confusion_matrix.columns, + y=mc_confusion_matrix.index, + colorscale='Blues', + hoverongaps = False)) + #add z values to heatmap + for i in range(len(mc_confusion_matrix.index)): + for j in range(len(mc_confusion_matrix.columns)): + fig.add_annotation(text=str(round(mc_confusion_matrix.values[i][j],2)), x=mc_confusion_matrix.columns[j], y=mc_confusion_matrix.index[i], + showarrow=False, font_size=font_size, font_color='black') + + fig.update_layout( + title='Confusion matrix for mc classes - ' + model + ' - ' + 'mc B. Acc: ' + str(round(model_prediction_pd.groupby('actual_mc_labels')['mc_accuracy'].mean().mean(),2)) \ + + ' - ' + 'sc B. Acc: ' + str(round(model_prediction_pd.groupby('actual_sc_labels')['sc_accuracy'].mean().mean(),2)) + '
' + \ + 'percent false novel: ' + str(round(num_outliers/num_rows,2)), + xaxis_nticks=36) + #label x axis and y axis + fig.update_xaxes(title_text='Predicted mc class') + fig.update_yaxes(title_text='Actual mc class') + if save_figs: + #save plot + if seperate_outliers: + fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_no_out_' + model + '.png') + #save svg + fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_no_out_' + model + '.svg') + else: + fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_outliers_' + model + '.png') + #save svg + fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_outliers_' + model + '.svg') + print('\n') + + +if __name__ == '__main__': + ##################################################################################################################### + mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json' + LC_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv' + path_to_models = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/' + + trained_on = 'full' #id or full + save_figs = True + + infer_aa = infer_relaxed_mirna = infer_hico = infer_ood = infer_other_affixes = infer_random = infer_fused = infer_na = infer_loco = False + + split = 'infer_hico'#sys.argv[1] + print(f'Running inference for {split}') + if split == 'infer_aa': + infer_aa = True + elif split == 'infer_relaxed_mirna': + infer_relaxed_mirna = True + elif split == 'infer_hico': + infer_hico = True + elif split == 'infer_ood': + infer_ood = True + elif split == 'infer_other_affixes': + infer_other_affixes = True + elif split == 'infer_random': + infer_random = True + elif split == 'infer_fused': + infer_fused = True + elif split == 'infer_na': + infer_na = True + elif split == 'infer_loco': + infer_loco = True + + ##################################################################################################################### + #only one of infer_aa or infer_relaxed_mirna or infer_normal or infer_ood or infer_hico should be true + if sum([infer_aa,infer_relaxed_mirna,infer_hico,infer_ood,infer_other_affixes,infer_random,infer_fused,infer_na,infer_loco]) != 1: + raise Exception('Only one of infer_aa or infer_relaxed_mirna or infer_normal or infer_ood or infer_hico or infer_other_affixes or infer_random or infer_fused or infer_na should be true') + + #set fig_prefix + if infer_aa: + fig_prefix = '5\'A-affixes' + elif infer_other_affixes: + fig_prefix = 'other_affixes' + elif infer_relaxed_mirna: + fig_prefix = 'Relaxed-miRNA' + elif infer_hico: + fig_prefix = 'LC-familiar' + elif infer_ood: + fig_prefix = 'LC-novel' + elif infer_random: + fig_prefix = 'Random' + elif infer_fused: + fig_prefix = 'Fused' + elif infer_na: + fig_prefix = 'NA' + elif infer_loco: + fig_prefix = 'LOCO' + + infer_df = load(LC_path) + if isinstance(infer_df,AnnData): + infer_df = infer_df.var + infer_df.set_index('sequence',inplace=True) + sc_to_mc_mapper_dict = load(mapping_dict_path) + #select hico sequences + hico_seqs = infer_df.index[infer_df['hico']].tolist() + art_affix_seqs = infer_df[~infer_df['five_prime_adapter_filter']].index.tolist() + + if infer_hico: + hico_seqs = hico_seqs + + if infer_aa: + hico_seqs = art_affix_seqs + + if infer_other_affixes: + hico_seqs = 
infer_df[~infer_df['hbdx_spikein_affix_filter']].index.tolist() + + if infer_na: + hico_seqs = infer_df[infer_df.subclass_name == 'no_annotation'].index.tolist() + + if infer_loco: + hico_seqs = infer_df[~infer_df['hico']][infer_df.subclass_name != 'no_annotation'].index.tolist() + + #for mirnas + if infer_relaxed_mirna: + #subclass name must contain miR, let, Let and not contain ; and that are not hico + mirnas_seqs = infer_df[infer_df.subclass_name.str.contains('miR') | infer_df.subclass_name.str.contains('let')][~infer_df.subclass_name.str.contains(';')].index.tolist() + #remove the ones that are true in ad.hico column + hico_seqs = list(set(mirnas_seqs).difference(set(hico_seqs))) + + #novel mirnas + #mirnas_not_in_train_mask = (ad['hico']==True).values * ~(ad['subclass_name'].isin(mirna_train_sc)).values * (ad['small_RNA_class_annotation'].isin(['miRNA'])) + #hicos = ad[mirnas_not_in_train_mask].index.tolist() + + + if infer_random: + #create random sequences + random_seqs = [] + while len(random_seqs) < 200: + random_seq = ''.join(random.choices(['A','C','G','T'], k=randint(18,30))) + if random_seq not in random_seqs: + random_seqs.append(random_seq) + hico_seqs = random_seqs + + if infer_fused: + hico_seqs = get_fused_seqs(hico_seqs,num_sequences=200) + + + #hico_seqs = ad[ad.subclass_name.str.contains('mir')][~ad.subclass_name.str.contains(';')]['subclass_name'].index.tolist() + hico_seqs = [seq for seq in hico_seqs if len(seq) <= 30] + #set cuda 1 + import os + os.environ["CUDA_VISIBLE_DEVICES"] = '1' + + #run prediction + prediction_pd = predict_transforna_all_models(hico_seqs,trained_on=trained_on,path_to_models=path_to_models) + prediction_pd['split'] = fig_prefix + #the if condition here is to make sure to filter seqs with sub classes not used in training + if not infer_ood and not infer_relaxed_mirna and not infer_hico: + prediction_pd.to_csv(f'transforna/bin/lc_files/{fig_prefix}_lev_dist_df.csv') + if infer_aa or infer_other_affixes or infer_random or infer_fused: + for model in prediction_pd.Model.unique(): + num_non_novel = sum(prediction_pd[prediction_pd.Model == model]['Is Familiar?']) + print(f'Number of non novel sequences for {model} is {num_non_novel}') + print(f'Percent non novel for {model} is {num_non_novel/len(prediction_pd[prediction_pd.Model == model])}, the lower the better') + + else: + if infer_na or infer_loco: + #print number of Is Familiar per model + for model in prediction_pd.Model.unique(): + num_non_novel = sum(prediction_pd[prediction_pd.Model == model]['Is Familiar?']) + print(f'Number of non novel sequences for {model} is {num_non_novel}') + print(f'Percent non novel for {model} is {num_non_novel/len(prediction_pd[prediction_pd.Model == model])}, the higher the better') + print('\n') + else: + #only to get classes used for training + prediction_single_pd = predict_transforna(hico_seqs[0],model='Seq',logits_flag = True,trained_on=trained_on,path_to_models=path_to_models) + sub_classes_used_for_training = prediction_single_pd.columns.tolist() + + + mc_classes_df,sc_classes_df = get_mc_sc(infer_df,hico_seqs,sub_classes_used_for_training,sc_to_mc_mapper_dict,ood_flag=infer_ood) + if infer_ood: + for model in prediction_pd.Model.unique(): + #filter sequences in prediction_pd to only include sequences in sc_classes_df + curr_prediction_pd = prediction_pd[prediction_pd['Sequence'].isin(sc_classes_df.index)] + #filter curr_prediction toonly include model + curr_prediction_pd = curr_prediction_pd[curr_prediction_pd.Model == model] + num_seqs = 
curr_prediction_pd.shape[0] + #filter Is Familiar + curr_prediction_pd = curr_prediction_pd[curr_prediction_pd['Is Familiar?']] + #filter sc_classes_df to only include sequences in curr_prediction_pd + curr_sc_classes_df = sc_classes_df[sc_classes_df.index.isin(curr_prediction_pd['Sequence'].values)] + #correct labels and remove the correct labels from the curr_prediction_pd + curr_prediction_pd['Net-Label'] = correct_labels(curr_prediction_pd['Net-Label'].values,curr_sc_classes_df[0].values,sc_to_mc_mapper_dict) + #filter rows in curr_prediction where Labels is equal to sc_classes_df[0] + curr_prediction_pd = curr_prediction_pd[curr_prediction_pd['Net-Label'].values != curr_sc_classes_df[0].values] + num_non_novel = len(curr_prediction_pd) + print(f'Number of non novel sequences for {model} is {num_non_novel}') + print(f'Percent non novel for {model} is {num_non_novel/num_seqs}, the lower the better') + print('\n') + else: + #filter prediction_pd to include only sequences in prediction_pd + + #compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers=False,fig_prefix = fig_prefix,save_figs=save_figs) + compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers=True,fig_prefix = fig_prefix,save_figs=save_figs) + + if infer_ood or infer_relaxed_mirna or infer_hico: + prediction_pd = prediction_pd[prediction_pd['Sequence'].isin(sc_classes_df.index)] + #save lev_dist_df + prediction_pd.to_csv(f'transforna/bin/lc_files/{fig_prefix}_lev_dist_df.csv') + + + + \ No newline at end of file diff --git a/transforna/src/__init__.py b/transforna/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..095ad93faeda5c87c6ff8d5fba16d67c13c82518 --- /dev/null +++ b/transforna/src/__init__.py @@ -0,0 +1,9 @@ +from .callbacks import * +from .callbacks.tbWriter import writer +from .inference import * +from .model import * +from .novelty_prediction import * +from .processing import * +from .score import * +from .train import * +from .utils import * diff --git a/transforna/src/callbacks/LRCallback.py b/transforna/src/callbacks/LRCallback.py new file mode 100644 index 0000000000000000000000000000000000000000..7edbef374d2e660104a99dec7fec21a346de0ad9 --- /dev/null +++ b/transforna/src/callbacks/LRCallback.py @@ -0,0 +1,174 @@ +import math +from collections.abc import Iterable +from math import cos, floor, log, pi + +import skorch +from torch.optim.lr_scheduler import _LRScheduler + +_LRScheduler + + +class CyclicCosineDecayLR(skorch.callbacks.Callback): + def __init__( + self, + optimizer, + init_interval, + min_lr, + len_param_groups, + base_lrs, + restart_multiplier=None, + restart_interval=None, + restart_lr=None, + last_epoch=-1, + ): + """ + Initialize new CyclicCosineDecayLR object + :param optimizer: (Optimizer) - Wrapped optimizer. + :param init_interval: (int) - Initial decay cycle interval. + :param min_lr: (float or iterable of floats) - Minimal learning rate. + :param restart_multiplier: (float) - Multiplication coefficient for increasing cycle intervals, + if this parameter is set, restart_interval must be None. + :param restart_interval: (int) - Restart interval for fixed cycle intervals, + if this parameter is set, restart_multiplier must be None. + :param restart_lr: (float or iterable of floats) - Optional, the learning rate at cycle restarts, + if not provided, initial learning rate will be used. + :param last_epoch: (int) - Last epoch. 
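+        Example (an illustrative sketch, not part of the original code; it assumes
+        `optimizer` is an already constructed torch optimizer and sets only
+        restart_interval, since restart_interval and restart_multiplier are
+        mutually exclusive):
+
+            scheduler = CyclicCosineDecayLR(
+                optimizer,
+                init_interval=10,     # length of the first decay cycle
+                min_lr=1e-5,          # lowest learning rate reached within a cycle
+                len_param_groups=len(optimizer.param_groups),
+                base_lrs=[group['lr'] for group in optimizer.param_groups],
+                restart_interval=10,  # fixed cycle length after the first cycle
+            )
+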
+ """ + self.len_param_groups = len_param_groups + if restart_interval is not None and restart_multiplier is not None: + raise ValueError( + "You can either set restart_interval or restart_multiplier but not both" + ) + + if isinstance(min_lr, Iterable) and len(min_lr) != self.len_param_groups: + raise ValueError( + "Expected len(min_lr) to be equal to len(optimizer.param_groups), " + "got {} and {} instead".format(len(min_lr), self.len_param_groups) + ) + + if isinstance(restart_lr, Iterable) and len(restart_lr) != len( + self.len_param_groups + ): + raise ValueError( + "Expected len(restart_lr) to be equal to len(optimizer.param_groups), " + "got {} and {} instead".format(len(restart_lr), self.len_param_groups) + ) + + if init_interval <= 0: + raise ValueError( + "init_interval must be a positive number, got {} instead".format( + init_interval + ) + ) + + group_num = self.len_param_groups + self._init_interval = init_interval + self._min_lr = [min_lr] * group_num if isinstance(min_lr, float) else min_lr + self._restart_lr = ( + [restart_lr] * group_num if isinstance(restart_lr, float) else restart_lr + ) + self._restart_interval = restart_interval + self._restart_multiplier = restart_multiplier + self.last_epoch = last_epoch + self.base_lrs = base_lrs + super().__init__() + + def on_batch_end(self, net, training, **kwargs): + if self.last_epoch < self._init_interval: + return self._calc(self.last_epoch, self._init_interval, self.base_lrs) + + elif self._restart_interval is not None: + cycle_epoch = ( + self.last_epoch - self._init_interval + ) % self._restart_interval + lrs = self.base_lrs if self._restart_lr is None else self._restart_lr + return self._calc(cycle_epoch, self._restart_interval, lrs) + + elif self._restart_multiplier is not None: + n = self._get_n(self.last_epoch) + sn_prev = self._partial_sum(n) + cycle_epoch = self.last_epoch - sn_prev + interval = self._init_interval * self._restart_multiplier ** n + lrs = self.base_lrs if self._restart_lr is None else self._restart_lr + return self._calc(cycle_epoch, interval, lrs) + else: + return self._min_lr + + def _calc(self, t, T, lrs): + return [ + min_lr + (lr - min_lr) * (1 + cos(pi * t / T)) / 2 + for lr, min_lr in zip(lrs, self._min_lr) + ] + + def _get_n(self, epoch): + a = self._init_interval + r = self._restart_multiplier + _t = 1 - (1 - r) * epoch / a + return floor(log(_t, r)) + + def _partial_sum(self, n): + a = self._init_interval + r = self._restart_multiplier + return a * (1 - r ** n) / (1 - r) + + +class LearningRateDecayCallback(skorch.callbacks.Callback): + def __init__( + self, + config, + ): + super().__init__() + self.lr_warmup_end = config.lr_warmup_end + self.lr_warmup_start = config.lr_warmup_start + self.learning_rate = config.learning_rate + self.warmup_batch = config.warmup_epoch * config.batch_per_epoch + self.final_batch = config.final_epoch * config.batch_per_epoch + + self.batch_idx = 0 + + def on_batch_end(self, net, training, **kwargs): + """ + + :param trainer: + :type trainer: + :param pl_module: + :type pl_module: + :param batch: + :type batch: + :param batch_idx: + :type batch_idx: + :param dataloader_idx: + :type dataloader_idx: + """ + # to avoid updating after validation batch + if training: + + if self.batch_idx < self.warmup_batch: + # linear warmup, in paper: start from 0.1 to 1 over lr_warmup_end batches + lr_mult = float(self.batch_idx) / float(max(1, self.warmup_batch)) + lr = self.lr_warmup_start + lr_mult * ( + self.lr_warmup_end - self.lr_warmup_start + ) + else: + # Cosine 
learning rate decay + progress = float(self.batch_idx - self.warmup_batch) / float( + max(1, self.final_batch - self.warmup_batch) + ) + lr = max( + self.learning_rate + + 0.5 + * (1.0 + math.cos(math.pi * progress)) + * (self.lr_warmup_end - self.learning_rate), + self.learning_rate, + ) + net.lr = lr + # for param_group in net.optimizer.param_groups: + # param_group["lr"] = lr + + self.batch_idx += 1 + + +class LRAnnealing(skorch.callbacks.Callback): + def on_epoch_end(self, net, **kwargs): + if not net.history[-1]["valid_loss_best"]: + net.lr /= 4.0 diff --git a/transforna/src/callbacks/__init__.py b/transforna/src/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c4025cdd08a00bce4b5e77b8052a89f0054627 --- /dev/null +++ b/transforna/src/callbacks/__init__.py @@ -0,0 +1,4 @@ +from .criterion import * +from .LRCallback import * +from .metrics import * +from .tbWriter import * diff --git a/transforna/src/callbacks/criterion.py b/transforna/src/callbacks/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..d02bef66c03bef21ab7384726ba2e1263ea95969 --- /dev/null +++ b/transforna/src/callbacks/criterion.py @@ -0,0 +1,165 @@ +import copy +import random +from typing import List + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LossFunction(nn.Module): + def __init__(self,main_config): + super(LossFunction, self).__init__() + self.train_config = main_config["train_config"] + self.model_config = main_config["model_config"] + self.batch_per_epoch = self.train_config.batch_per_epoch + self.warm_up_annealing = ( + self.train_config.warmup_epoch * self.batch_per_epoch + ) + self.num_embed_hidden = self.model_config.num_embed_hidden + self.batch_idx = 0 + self.loss_anealing_term = 0 + + + class_weights = self.model_config.class_weights + #TODO: use device as in main_config + class_weights = torch.FloatTensor([float(x) for x in class_weights]) + + if self.model_config.num_classes > 2: + self.clf_loss_fn = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=self.train_config.label_smoothing_clf,reduction='none') + else: + self.clf_loss_fn = self.focal_loss + + + # @staticmethod + def cosine_similarity_matrix( + self, gene_embedd: torch.Tensor, second_input_embedd: torch.Tensor, annealing=True + ) -> torch.Tensor: + # if annealing is true, then this function is being called from Net.predict and + # doesnt pass the instantiated object LossFunction, therefore no access to self. + # in Predict we also just need the max of predictions. + # for some reason, skorch only passes the LossFunction initialized object, only + # from get_loss fn. 
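+        # Illustrative note (added comment): during warm-up the similarity logits
+        # computed below are multiplied by loss_anealing_term, which grows linearly
+        # from 1 towards 1 + sqrt(num_embed_hidden). For example, assuming
+        # num_embed_hidden = 256, the scaling factor approaches 1 + 16 = 17 by the
+        # end of warm-up, sharpening the contrast between matching and
+        # non-matching pairs in the in-batch similarity matrix.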
+ + assert gene_embedd.size(0) == second_input_embedd.size(0) + + cosine_similarity = torch.matmul(gene_embedd, second_input_embedd.T) + + if annealing: + if self.batch_idx < self.warm_up_annealing: + self.loss_anealing_term = 1 + ( + self.batch_idx / self.warm_up_annealing + ) * torch.sqrt(torch.tensor(self.num_embed_hidden)) + + cosine_similarity *= self.loss_anealing_term + + return cosine_similarity + def get_similar_labels(self,y:torch.Tensor): + ''' + This function recieves y, the labels tensor + It creates a list of lists containing at every index a list(min_len = 2) of the indices of the labels that are similar + ''' + # create a test array + labels_y = y[:,0].cpu().detach().numpy() + + # creates an array of indices, sorted by unique element + idx_sort = np.argsort(labels_y) + + # sorts records array so all unique elements are together + sorted_records_array = labels_y[idx_sort] + + # returns the unique values, the index of the first occurrence of a value, and the count for each element + vals, idx_start, count = np.unique(sorted_records_array, return_counts=True, return_index=True) + + # splits the indices into separate arrays + res = np.split(idx_sort, idx_start[1:]) + #filter them with respect to their size, keeping only items occurring more than once + vals = vals[count > 1] + res = filter(lambda x: x.size > 1, res) + + indices_similar_labels = [] + similar_labels = [] + for r in res: + indices_similar_labels.append(list(r)) + similar_labels.append(list(labels_y[r])) + + return indices_similar_labels,similar_labels + + def get_triplet_samples(self,indices_similar_labels,similar_labels): + ''' + This function creates three lists, positives, anchors and negatives + Each index in the three lists correpond to a single triplet + ''' + positives,anchors,negatives = [],[],[] + for idx_similar_labels in indices_similar_labels: + random_indices = random.sample(range(len(idx_similar_labels)), 2) + positives.append(idx_similar_labels[random_indices[0]]) + anchors.append(idx_similar_labels[random_indices[1]]) + + negatives = copy.deepcopy(positives) + random.shuffle(negatives) + while (np.array(positives) == np.array(negatives)).any(): + random.shuffle(negatives) + + return positives,anchors,negatives + def get_triplet_loss(self,y,gene_embedd,second_input_embedd): + ''' + This function computes triplet loss by creating triplet samples of positives, negatives and anchors + The objective is to decrease the distance of the embeddings between the anchors and the positives + while increasing the distance between the anchor and the negatives. 
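As a toy illustration of get_similar_labels and get_triplet_samples above (a hedged sketch with an invented label vector, not training data): sequences sharing a label are grouped by index, one positive/anchor pair is drawn per group, and the negatives are a derangement of the positives.

import random
import numpy as np

labels = np.array([3, 1, 3, 2, 1, 3])                     # toy labels, illustration only
idx_sort = np.argsort(labels)
vals, idx_start, count = np.unique(labels[idx_sort], return_index=True, return_counts=True)
groups = [list(g) for g in np.split(idx_sort, idx_start[1:]) if g.size > 1]
# groups -> [[1, 4], [0, 2, 5]]  (indices carrying label 1 and label 3)

positives, anchors = [], []
for g in groups:
    p, a = random.sample(range(len(g)), 2)                # two distinct members per group
    positives.append(g[p]); anchors.append(g[a])
negatives = positives[::-1]   # the repo reshuffles until no index keeps its position;
                              # reversing suffices here because the two entries differ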
+ This is done seperately for both the embeddings, gene embedds 0 and ss embedds 1 + ''' + #get similar labels + indices_similar_labels,similar_labels = self.get_similar_labels(y) + #insuring that there's at least two sets of labels in a given list (indices_similar_labels) + if len(indices_similar_labels)>1: + #get triplet samples + positives,anchors,negatives = self.get_triplet_samples(indices_similar_labels,similar_labels) + #get triplet loss for gene embedds + gene_embedd_triplet_loss = self.triplet_loss(gene_embedd[positives,:], + gene_embedd[anchors,:], + gene_embedd[negatives,:]) + #get triplet loss for ss embedds + second_input_embedd_triplet_loss = self.triplet_loss(second_input_embedd[positives,:], + second_input_embedd[anchors,:], + second_input_embedd[negatives,:]) + return gene_embedd_triplet_loss + second_input_embedd_triplet_loss + else: + return 0 + + def focal_loss(self,predicted_labels,y): + y = y.unsqueeze(dim=1) + y_new = torch.zeros(y.shape[0], 2).type(torch.cuda.FloatTensor) + y_new[range(y.shape[0]), y[:,0]]=1 + BCE_loss = F.binary_cross_entropy_with_logits(predicted_labels.float(), y_new.float(), reduction='none') + pt = torch.exp(-BCE_loss) # prevents nans when probability 0 + F_loss = (1-pt)**2 * BCE_loss + loss = 10*F_loss.mean() + return loss + + def contrastive_loss(self,cosine_similarity,batch_size): + j = -torch.sum(torch.diagonal(cosine_similarity)) + + cosine_similarity.diagonal().copy_(torch.zeros(cosine_similarity.size(0))) + + j = (1 - self.train_config.label_smoothing_sim) * j + ( + self.train_config.label_smoothing_sim / (cosine_similarity.size(0) * (cosine_similarity.size(0) - 1)) + ) * torch.sum(cosine_similarity) + + j += torch.sum(torch.logsumexp(cosine_similarity, dim=0)) + + if j < 0: + j = j-j + return j/batch_size + + def forward(self, embedds: List[torch.Tensor], y=None) -> torch.Tensor: + self.batch_idx += 1 + gene_embedd, second_input_embedd, predicted_labels,curr_epoch = embedds + + + loss = self.clf_loss_fn(predicted_labels,y.squeeze()) + + + + return loss diff --git a/transforna/src/callbacks/metrics.py b/transforna/src/callbacks/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..7dcd6b4de5b48f22a9544f437ed106333317d661 --- /dev/null +++ b/transforna/src/callbacks/metrics.py @@ -0,0 +1,218 @@ +import os + +import numpy as np +import skorch +import torch +from sklearn.metrics import confusion_matrix, make_scorer +from skorch.callbacks import BatchScoring +from skorch.callbacks.scoring import ScoringBase, _cache_net_forward_iter +from skorch.callbacks.training import Checkpoint + +from .LRCallback import LearningRateDecayCallback + +writer = None + +def accuracy_score(y_true, y_pred: torch.tensor,task:str=None,mirna_flag:bool = False): + #sample + + # premirna + if task == "premirna": + y_pred = y_pred[:,:-1] + miRNA_idx = np.where(y_true.squeeze()==mirna_flag) + correct = torch.max(y_pred,1).indices.cpu().numpy()[miRNA_idx] == mirna_flag + return sum(correct) + + # sncrna + if task == "sncrna": + y_pred = y_pred[:,:-1] + # correct is of [samples], where each entry is true if it was found in top k + correct = torch.max(y_pred,1).indices.cpu().numpy() == y_true.squeeze() + + return sum(correct) / y_pred.shape[0] + + +def accuracy_score_tcga(y_true, y_pred): + + if torch.is_tensor(y_pred): + y_pred = y_pred.clone().detach().cpu().numpy() + if torch.is_tensor(y_true): + y_true = y_true.clone().detach().cpu().numpy() + + #y pred contains logits | samples weights + sample_weight = y_pred[:,-1] + y_pred = 
np.argmax(y_pred[:,:-1],axis=1) + + C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight) + with np.errstate(divide='ignore', invalid='ignore'): + per_class = np.diag(C) / C.sum(axis=1) + if np.any(np.isnan(per_class)): + per_class = per_class[~np.isnan(per_class)] + score = np.mean(per_class) + return score + +def score_callbacks(cfg): + + acc_scorer = make_scorer(accuracy_score,task=cfg["task"]) + if cfg['task'] == 'tcga': + acc_scorer = make_scorer(accuracy_score_tcga) + + + if cfg["task"] == "premirna": + acc_scorer_mirna = make_scorer(accuracy_score,task=cfg["task"],mirna_flag = True) + + val_score_callback_mirna = BatchScoringPremirna( mirna_flag=True, + scoring = acc_scorer_mirna, lower_is_better=False, name="val_acc_mirna") + + train_score_callback_mirna = BatchScoringPremirna(mirna_flag=True, + scoring = acc_scorer_mirna, on_train=True, lower_is_better=False, name="train_acc_mirna") + + val_score_callback = BatchScoringPremirna(mirna_flag=False, + scoring = acc_scorer, lower_is_better=False, name="val_acc") + + train_score_callback = BatchScoringPremirna(mirna_flag=False, + scoring = acc_scorer, on_train=True, lower_is_better=False, name="train_acc") + + + scoring_callbacks = [ + train_score_callback, + train_score_callback_mirna + ] + if cfg["train_split"]: + scoring_callbacks.extend([val_score_callback_mirna,val_score_callback]) + + if cfg["task"] in ["sncrna", "tcga"]: + + val_score_callback = BatchScoring(acc_scorer, lower_is_better=False, name="val_acc") + train_score_callback = BatchScoring( + acc_scorer, on_train=True, lower_is_better=False, name="train_acc" + ) + scoring_callbacks = [train_score_callback] + + #tcga dataset has a predifined valid split, so train_split is false, but still valid metric is required + #TODO: remove predifined valid from tcga from prepare_data_tcga + if cfg["train_split"] or cfg['task'] == 'tcga': + scoring_callbacks.append(val_score_callback) + + return scoring_callbacks + +def get_callbacks(path,cfg): + + callback_list = [("lrcallback", LearningRateDecayCallback)] + if cfg['tensorboard'] == True: + from .tbWriter import writer + callback_list.append(MetricsVizualization) + + if (cfg["train_split"] or cfg['task'] == 'tcga') and cfg['inference'] == False: + monitor = "val_acc_best" + if cfg['trained_on'] == 'full': + monitor = 'train_acc_best' + ckpt_path = path+"/ckpt/" + try: + os.mkdir(ckpt_path) + except: + pass + model_name = f'model_params_{cfg["task"]}.pt' + callback_list.append(Checkpoint(monitor=monitor, dirname=ckpt_path,f_params=model_name)) + + scoring_callbacks = score_callbacks(cfg) + #TODO: For some reason scoring callbaks have to be inserted before checpoint and metrics viz callbacks + #otherwise NeuralNet notify function throws an exception + callback_list[1:1] = scoring_callbacks + + return callback_list + + +class MetricsVizualization(skorch.callbacks.Callback): + def __init__(self, batch_idx=0) -> None: + super().__init__() + self.batch_idx = batch_idx + + # TODO: Change to display metrics at epoch ends + def on_batch_end(self, net, training, **kwargs): + # validation batch + if not training: + # log val accuracy. 
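For context on accuracy_score_tcga above: it reports the mean per-class recall (balanced accuracy) computed from a sample-weighted confusion matrix. A minimal standalone check with invented toy labels (the real call passes the per-sample weights carried in the last logits column):

import numpy as np
from sklearn.metrics import confusion_matrix, balanced_accuracy_score

y_true = np.array([0, 0, 1, 1, 2, 2])            # toy ground truth
y_pred = np.array([0, 1, 1, 1, 2, 0])            # toy predictions
w = np.ones_like(y_true, dtype=float)            # per-sample weights, all 1 here

C = confusion_matrix(y_true, y_pred, sample_weight=w)
per_class_recall = np.diag(C) / C.sum(axis=1)    # row-normalised diagonal = recall per class
score = np.nanmean(per_class_recall)             # ignore classes absent from y_true
assert np.isclose(score, balanced_accuracy_score(y_true, y_pred))   # 2/3 for this toy case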
accessing net.history:[ epoch ,batches, last batch,column in batch] + writer.add_scalar( + "Accuracy/val_acc", + net.history[-1, "batches", -1, "val_acc"], + self.batch_idx, + ) + # log val loss + writer.add_scalar( + "Loss/val_loss", + net.history[-1, "batches", -1, "valid_loss"], + self.batch_idx, + ) + # update batch idx after validation on batch is computed + # train batch + else: + # log lr + writer.add_scalar("Metrics/lr", net.lr, self.batch_idx) + # log train accuracy + writer.add_scalar( + "Accuracy/train_acc", + net.history[-1, "batches", -1, "train_acc"], + self.batch_idx, + ) + # log train loss + writer.add_scalar( + "Loss/train_loss", + net.history[-1, "batches", -1, "train_loss"], + self.batch_idx, + ) + self.batch_idx += 1 + +class BatchScoringPremirna(ScoringBase): + def __init__(self,mirna_flag:bool = False,*args,**kwargs): + super().__init__(*args,**kwargs) + #self.total_num_samples = total_num_samples + self.total_num_samples = 0 + self.mirna_flag = mirna_flag + self.first_batch_flag = True + def on_batch_end(self, net, X, y, training, **kwargs): + if training != self.on_train: + return + + y_preds = [kwargs['y_pred']] + #only for the first batch: get no. of samples belonging to same class samples + if self.first_batch_flag: + self.total_num_samples += sum(kwargs["batch"][1] == self.mirna_flag).detach().cpu().numpy()[0] + + with _cache_net_forward_iter(net, self.use_caching, y_preds) as cached_net: + # In case of y=None we will not have gathered any samples. + # We expect the scoring function to deal with y=None. + y = None if y is None else self.target_extractor(y) + try: + score = self._scoring(cached_net, X, y) + cached_net.history.record_batch(self.name_, score) + except KeyError: + pass + def get_avg_score(self, history): + if self.on_train: + bs_key = 'train_batch_size' + else: + bs_key = 'valid_batch_size' + + weights, scores = list(zip( + *history[-1, 'batches', :, [bs_key, self.name_]])) + #score_avg = np.average(scores, weights=weights) + score_avg = sum(scores)/self.total_num_samples + return score_avg + + # pylint: disable=unused-argument + def on_epoch_end(self, net, **kwargs): + self.first_batch_flag = False + history = net.history + try: # don't raise if there is no valid data + history[-1, 'batches', :, self.name_] + except KeyError: + return + + score_avg = self.get_avg_score(history) + is_best = self._is_best_score(score_avg) + if is_best: + self.best_score_ = score_avg + + history.record(self.name_, score_avg) + if is_best is not None: + history.record(self.name_ + '_best', bool(is_best)) diff --git a/transforna/src/callbacks/tbWriter.py b/transforna/src/callbacks/tbWriter.py new file mode 100644 index 0000000000000000000000000000000000000000..9fe9d315e6e56da04c89e18f31f145ae7979ca7f --- /dev/null +++ b/transforna/src/callbacks/tbWriter.py @@ -0,0 +1,6 @@ + +from pathlib import Path + +from torch.utils.tensorboard import SummaryWriter + +writer = SummaryWriter(str(Path(__file__).parent.parent.parent.absolute())+"/runs/") diff --git a/transforna/src/inference/__init__.py b/transforna/src/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..356e0b5aea3778118766fdd245f2750a65a47f4b --- /dev/null +++ b/transforna/src/inference/__init__.py @@ -0,0 +1,3 @@ +from .inference_api import * +from .inference_benchmark import * +from .inference_tcga import * diff --git a/transforna/src/inference/inference_api.py b/transforna/src/inference/inference_api.py new file mode 100644 index 
0000000000000000000000000000000000000000..5ebf7624b50d072579b609ead031ef4167e87404 --- /dev/null +++ b/transforna/src/inference/inference_api.py @@ -0,0 +1,243 @@ + +import logging +import warnings +from argparse import ArgumentParser +from contextlib import redirect_stdout +from datetime import datetime +from pathlib import Path +from typing import List + +import numpy as np +import pandas as pd +from hydra.utils import instantiate +from omegaconf import OmegaConf +from sklearn.preprocessing import StandardScaler +from umap import UMAP + +from ..novelty_prediction.id_vs_ood_nld_clf import get_closest_ngbr_per_split +from ..processing.seq_tokenizer import SeqTokenizer +from ..utils.file import load +from ..utils.tcga_post_analysis_utils import Results_Handler +from ..utils.utils import (get_model, infer_from_pd, + prepare_inference_results_tcga, + update_config_with_inference_params) + +logger = logging.getLogger(__name__) + +warnings.filterwarnings("ignore") + +def aggregate_ensemble_model(lev_dist_df:pd.DataFrame): + ''' + This function aggregates the predictions of the ensemble model by choosing the model with the lowest and the highest NLD per query sequence. + If the lowest NLD is lower than Novelty Threshold, then the model with the lowest NLD is chosen as the ensemble prediction. + Otherwise, the model with the highest NLD is chosen as the ensemble prediction. + ''' + #for every sequence, if at least one model scores an NLD < Novelty Threshold, then get the one with the least NLD as the ensemble prediction + #otherwise, get the highest NLD. + #get the minimum NLD per query sequence + #remove the baseline model + baseline_df = lev_dist_df[lev_dist_df['Model'] == 'Baseline'].reset_index(drop=True) + lev_dist_df = lev_dist_df[lev_dist_df['Model'] != 'Baseline'].reset_index(drop=True) + min_lev_dist_df = lev_dist_df.iloc[lev_dist_df.groupby('Sequence')['NLD'].idxmin().values] + #get the maximum NLD per query sequence + max_lev_dist_df = lev_dist_df.iloc[lev_dist_df.groupby('Sequence')['NLD'].idxmax().values] + #choose between each row in min_lev_dist_df and max_lev_dist_df based on the value of Novelty Threshold + novel_mask_df = min_lev_dist_df['NLD'] > min_lev_dist_df['Novelty Threshold'] + #get the rows where NLD is lower than Novelty Threshold + min_lev_dist_df = min_lev_dist_df[~novel_mask_df.values] + #get the rows where NLD is higher than Novelty Threshold + max_lev_dist_df = max_lev_dist_df[novel_mask_df.values] + #merge min_lev_dist_df and max_lev_dist_df + ensemble_lev_dist_df = pd.concat([min_lev_dist_df,max_lev_dist_df]) + #add ensemble model + ensemble_lev_dist_df['Model'] = 'Ensemble' + #add ensemble_lev_dist_df to lev_dist_df + lev_dist_df = pd.concat([lev_dist_df,ensemble_lev_dist_df,baseline_df]) + return lev_dist_df.reset_index(drop=True) + + +def read_inference_model_config(model:str,mc_or_sc,trained_on:str,path_to_models:str): + transforna_folder = "TransfoRNA_ID" + if trained_on == "full": + transforna_folder = "TransfoRNA_FULL" + + model_path = f"{path_to_models}/{transforna_folder}/{mc_or_sc}/{model}/meta/hp_settings.yaml" + cfg = OmegaConf.load(model_path) + return cfg + +def predict_transforna(sequences: List[str], model: str = "Seq-Rev", mc_or_sc:str='sub_class',\ + logits_flag:bool = False,attention_flag:bool = False,\ + similarity_flag:bool=False,n_sim:int=3,embedds_flag:bool = False, \ + umap_flag:bool = False,trained_on:str='full',path_to_models:str='') -> pd.DataFrame: + ''' + This function predicts the major class or sub class of a list of sequences using 
the TransfoRNA model. + Additionaly, it can return logits, attention scores, similarity scores, gene embeddings or umap embeddings. + + Input: + sequences: list of sequences to predict + model: model to use for prediction + mc_or_sc: models trained on major class or sub class + logits_flag: whether to return logits + attention_flag: whether to return attention scores (obtained from the self-attention layer) + similarity_flag: whether to return explanatory/similar sequences in the training set + n_sim: number of similar sequences to return + embedds_flag: whether to return embeddings of the sequences + umap_flag: whether to return umap embeddings + trained_on: whether to use the model trained on the full dataset or the ID dataset + Output: + pd.DataFrame with the predictions + ''' + #assers that only one flag is True + assert sum([logits_flag,attention_flag,similarity_flag,embedds_flag,umap_flag]) <= 1, 'One option at most can be True' + # capitalize the first letter of the model and the first letter after the - + model = "-".join([word.capitalize() for word in model.split("-")]) + cfg = read_inference_model_config(model,mc_or_sc,trained_on,path_to_models) + cfg = update_config_with_inference_params(cfg,mc_or_sc,trained_on,path_to_models) + root_dir = Path(__file__).parents[1].absolute() + + with redirect_stdout(None): + cfg, net = get_model(cfg, root_dir) + #original_infer_pd might include seqs that are longer than input model. if so, infer_pd contains the trimmed sequences + infer_pd = pd.Series(sequences, name="Sequences").to_frame() + predicted_labels, logits, gene_embedds_df,attn_scores_pd,all_data, max_len, net,_ = infer_from_pd(cfg, net, infer_pd, SeqTokenizer,attention_flag) + + if model == 'Seq': + gene_embedds_df = gene_embedds_df.iloc[:,:int(gene_embedds_df.shape[1]/2)] + if logits_flag: + cfg['log_logits'] = True + prepare_inference_results_tcga(cfg, predicted_labels, logits, all_data, max_len) + infer_pd = all_data["infere_rna_seq"] + + if logits_flag: + logits_df = infer_pd.rename_axis("Sequence").reset_index() + logits_cols = [col for col in infer_pd.columns if "Logits" in col] + logits_df = infer_pd[logits_cols] + logits_df.columns = pd.MultiIndex.from_tuples(logits_df.columns, names=["Logits", "Sub Class"]) + logits_df.columns = logits_df.columns.droplevel(0) + return logits_df + + elif attention_flag: + return attn_scores_pd + + elif embedds_flag: + return gene_embedds_df + + else: #return table with predictions, entropy, threshold, is familiar + #add aa predictions to infer_pd + embedds_path = '/'.join(cfg['inference_settings']["model_path"].split('/')[:-2])+'/embedds' + results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train']) + results.get_knn_model() + lv_threshold = load(results.analysis_path+"/novelty_model_coef")["Threshold"] + logger.info(f'computing levenstein distance for the inference set') + #prepare infer split + gene_embedds_df.columns = results.embedds_cols[:len(gene_embedds_df.columns)] + #add index of gene_embedds_df to be a column with name results.seq_col + gene_embedds_df[results.seq_col] = gene_embedds_df.index + #set gene_embedds_df as the new infer split + results.splits_df_dict['infer_df'] = gene_embedds_df + + + _,_,top_n_seqs,top_n_labels,distances,lev_dist = get_closest_ngbr_per_split(results,'infer',num_neighbors=n_sim) + + if similarity_flag: + #create df + sim_df = pd.DataFrame() + #populate query sequences and duplicate them n times + sequences = gene_embedds_df.index.tolist() + #duplicate each sequence n_sim times 
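The ensemble aggregation in aggregate_ensemble_model above reduces to a pair of groupby idxmin/idxmax selections; a hedged pandas sketch with invented distances (two query sequences, two models, threshold 0.3):

import pandas as pd

lev_dist_df = pd.DataFrame({
    "Sequence":          ["AAA", "AAA", "CCC", "CCC"],
    "Model":             ["Seq", "Seq-Struct", "Seq", "Seq-Struct"],
    "NLD":               [0.05, 0.20, 0.60, 0.70],
    "Novelty Threshold": [0.30, 0.30, 0.30, 0.30],
})

min_df = lev_dist_df.loc[lev_dist_df.groupby("Sequence")["NLD"].idxmin()]
max_df = lev_dist_df.loc[lev_dist_df.groupby("Sequence")["NLD"].idxmax()]
novel = min_df["NLD"] > min_df["Novelty Threshold"]        # True -> no model found it familiar

ensemble = pd.concat([min_df[~novel.values], max_df[novel.values]]).assign(Model="Ensemble")
# "AAA" keeps its best (lowest-NLD) model; "CCC" is reported through its highest NLD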
+ sequences_duplicated = [seq for seq in sequences for _ in range(n_sim)] + sim_df['Sequence'] = sequences_duplicated + #assign top_5_seqs list to df column + sim_df[f'Explanatory Sequence'] = top_n_seqs + sim_df['NLD'] = lev_dist + sim_df['Explanatory Label'] = top_n_labels + sim_df['Novelty Threshold'] = lv_threshold + #for every query sequence, order the NLD in a increasing order + sim_df = sim_df.sort_values(by=['Sequence','NLD'],ascending=[False,True]) + return sim_df + + logger.info(f'num of hico based on entropy novelty prediction is {sum(infer_pd["Is Familiar?"])}') + #for every n_sim elements in the list, get the smallest levenstein distance + lv_dist_closest = [min(lev_dist[i:i+n_sim]) for i in range(0,len(lev_dist),n_sim)] + top_n_labels_closest = [top_n_labels[i:i+n_sim][np.argmin(lev_dist[i:i+n_sim])] for i in range(0,len(lev_dist),n_sim)] + top_n_seqs_closest = [top_n_seqs[i:i+n_sim][np.argmin(lev_dist[i:i+n_sim])] for i in range(0,len(lev_dist),n_sim)] + infer_pd['Is Familiar?'] = [True if lv pd.DataFrame: + """ + Predicts the labels of the sequences using all the models available in the transforna package. + If non of the flags are true, it constructs and aggrgates the output of the ensemble model. + + Input: + sequences: list of sequences to predict + mc_or_sc: models trained on major class or sub class + logits_flag: whether to return logits + attention_flag: whether to return attention scores (obtained from the self-attention layer) + similarity_flag: whether to return explanatory/similar sequences in the training set + n_sim: number of similar sequences to return + embedds_flag: whether to return embeddings of the sequences + umap_flag: whether to return umap embeddings + trained_on: whether to use the model trained on the full dataset or the ID dataset + Output: + df: dataframe with the predictions + """ + now = datetime.now() + before_time = now.strftime("%H:%M:%S") + models = ["Baseline","Seq", "Seq-Seq", "Seq-Struct", "Seq-Rev"] + if similarity_flag or embedds_flag: #remove baseline, takes long time + models = ["Baseline","Seq", "Seq-Seq", "Seq-Struct", "Seq-Rev"] + if attention_flag: #remove single based transformer models + models = ["Seq", "Seq-Struct", "Seq-Rev"] + df = None + for model in models: + logger.info(model) + df_ = predict_transforna(sequences, model, mc_or_sc,logits_flag,attention_flag,similarity_flag,n_sim,embedds_flag,umap_flag,trained_on=trained_on,path_to_models = path_to_models) + df_["Model"] = model + df = pd.concat([df, df_], axis=0) + #aggregate ensemble model if not of the flags are true + if not logits_flag and not attention_flag and not similarity_flag and not embedds_flag and not umap_flag: + df = aggregate_ensemble_model(df) + + now = datetime.now() + after_time = now.strftime("%H:%M:%S") + delta_time = datetime.strptime(after_time, "%H:%M:%S") - datetime.strptime(before_time, "%H:%M:%S") + logger.info(f"Time taken: {delta_time}") + + return df + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("sequences", nargs="+") + parser.add_argument("--logits_flag", nargs="?", const = True,default=False) + parser.add_argument("--attention_flag", nargs="?", const = True,default=False) + parser.add_argument("--similarity_flag", nargs="?", const = True,default=False) + parser.add_argument("--n_sim", nargs="?", const = 3,default=3) + parser.add_argument("--embedds_flag", nargs="?", const = True,default=False) + parser.add_argument("--trained_on", nargs="?", const = True,default="full") + 
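A hedged usage sketch of the API defined above (the sequences and the models directory are placeholders, and the import assumes the repository root is on PYTHONPATH; keyword names follow the signatures and docstrings in this file):

from transforna.src.inference.inference_api import (predict_transforna,
                                                     predict_transforna_all_models)

seqs = ["TACCCTGTAGATCCGAATTTGTG", "TGAGGTAGTAGGTTGTATAGTT"]   # placeholder sequences

# Single model, plain predictions.
preds = predict_transforna(seqs, model="Seq-Rev", mc_or_sc="sub_class",
                           trained_on="full", path_to_models="/path/to/models")

# Explanatory nearest neighbours from the training set, three per query.
sim = predict_transforna(seqs, model="Seq-Struct", similarity_flag=True, n_sim=3,
                         path_to_models="/path/to/models")

# Every model plus the aggregated "Ensemble" row per sequence.
all_preds = predict_transforna_all_models(seqs, path_to_models="/path/to/models")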
predict_transforna_all_models(**vars(parser.parse_args())) diff --git a/transforna/src/inference/inference_benchmark.py b/transforna/src/inference/inference_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..85a7d40665b9cc69c9500413375fa2b395a06606 --- /dev/null +++ b/transforna/src/inference/inference_benchmark.py @@ -0,0 +1,48 @@ + +from pathlib import Path +from typing import Dict + +from ..callbacks.metrics import accuracy_score +from ..processing.seq_tokenizer import SeqTokenizer +from ..score.score import infer_from_model, infer_testset +from ..utils.file import load, save +from ..utils.utils import * + + +def infer_benchmark(cfg:Dict= None,path:str = None): + if cfg['tensorboard']: + from ..callbacks.tbWriter import writer + + model = cfg["model_name"]+'_'+cfg['task'] + + #set seed + set_seed_and_device(cfg["seed"],cfg["device_number"]) + #get data + ad = load(cfg["train_config"].dataset_path_train) + + #instantiate dataset class + dataset_class = SeqTokenizer(ad.var,cfg) + test_data = load(cfg["train_config"].dataset_path_test) + #prepare data for training and inference + all_data = prepare_data_benchmark(dataset_class,test_data,cfg) + + + + #sync skorch config with params in train and model config + sync_skorch_with_config(cfg["model"]["skorch_model"],cfg) + + # instantiate skorch model + net = instantiate_predictor(cfg["model"]["skorch_model"], cfg,path) + net.initialize() + net.load_params(f_params=f'{cfg["inference_settings"]["model_path"]}') + + #perform inference on task specific testset + if cfg["inference_settings"]["infere_original_testset"]: + infer_testset(net,cfg,all_data,accuracy_score) + + #inference on custom data + predicted_labels,logits,_,_ = infer_from_model(net,all_data["infere_data"]) + prepare_inference_results_benchmarck(net,cfg,predicted_labels,logits,all_data) + save(path=Path(__file__).parent.parent.absolute() / f'inference_results_{model}',data=all_data["infere_rna_seq"]) + if cfg['tensorboard']: + writer.close() \ No newline at end of file diff --git a/transforna/src/inference/inference_tcga.py b/transforna/src/inference/inference_tcga.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8af085ef465792918166618dd35ae8b41670e6 --- /dev/null +++ b/transforna/src/inference/inference_tcga.py @@ -0,0 +1,38 @@ + +from typing import Dict + +from anndata import AnnData + +from ..processing.seq_tokenizer import SeqTokenizer +from ..utils.file import load +from ..utils.utils import * + + +def infer_tcga(cfg:Dict= None,path:str = None): + if cfg['tensorboard']: + from ..callbacks.tbWriter import writer + cfg,net = get_model(cfg,path) + inference_path = cfg['inference_settings']['sequences_path'] + original_infer_df = load(inference_path, index_col=0) + if isinstance(original_infer_df,AnnData): + original_infer_df = original_infer_df.var + predicted_labels,logits,_,_,all_data,max_len,net,infer_df = infer_from_pd(cfg,net,original_infer_df,SeqTokenizer) + + #create inference_output if it does not exist + if not os.path.exists(f"inference_output"): + os.makedirs(f"inference_output") + if cfg['log_embedds']: + embedds_pd = log_embedds(cfg,net,all_data['infere_rna_seq']) + embedds_pd.to_csv(f"inference_output/{cfg['model_name']}_embedds.csv") + + prepare_inference_results_tcga(cfg,predicted_labels,logits,all_data,max_len) + + #if sequences were trimmed, add mapping of trimmed sequences to original sequences + if original_infer_df.shape[0] != infer_df.shape[0]: + all_data["infere_rna_seq"] = 
add_original_seqs_to_predictions(infer_df,all_data['infere_rna_seq']) + #save + all_data["infere_rna_seq"].to_csv(f"inference_output/{cfg['model_name']}_inference_results.csv") + + if cfg['tensorboard']: + writer.close() + return predicted_labels \ No newline at end of file diff --git a/transforna/src/model/__init__.py b/transforna/src/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d63a2cc531569bee4b83d5d043b58f8c1f20ba --- /dev/null +++ b/transforna/src/model/__init__.py @@ -0,0 +1,2 @@ +from .model_components import * +from .skorchWrapper import * diff --git a/transforna/src/model/model_components.py b/transforna/src/model/model_components.py new file mode 100644 index 0000000000000000000000000000000000000000..020206756da5ca20b92e584d3c4666f8c152cf85 --- /dev/null +++ b/transforna/src/model/model_components.py @@ -0,0 +1,449 @@ + +import logging +import math +import random +from typing import Dict, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import DictConfig +from torch.nn.modules.normalization import LayerNorm + +logger = logging.getLogger(__name__) + +def circulant_mask(n: int, window: int) -> torch.Tensor: + """Calculate the relative attention mask, calculated once when model instatiated, as a subset of this matrix + will be used for a input length less than max. + i,j represent relative token positions in this matrix and in the attention scores matrix, + this mask enables attention scores to be set to 0 if further than the specified window length + + :param n: a fixed parameter set to be larger than largest max sequence length across batches + :param window: [window length], + :return relative attention mask + """ + circulant_t = torch.zeros(n, n) + # [0, 1, 2, ..., window, -1, -2, ..., window] + offsets = [0] + [i for i in range(window + 1)] + [-i for i in range(window + 1)] + if window >= n: + return torch.ones(n, n) + for offset in offsets: + # size of the 1-tensor depends on the length of the diagonal + circulant_t.diagonal(offset=offset).copy_(torch.ones(n - abs(offset))) + return circulant_t + + +class SelfAttention(nn.Module): + + """normal query, key, value based self attention but with relative attention functionality + and a learnable bias encoding relative token position which is added to the attention scores before the softmax""" + + def __init__(self, config: DictConfig, relative_attention: int): + """init self attention weight of each key, query, value and output projection layer. 
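A toy check of circulant_mask above (sizes invented, and it assumes the function is in scope): with n=5 and window=1 it yields a tri-diagonal band of ones, which later zeroes attention between tokens more than one position apart.

import torch

expected = torch.tensor([[1., 1., 0., 0., 0.],
                         [1., 1., 1., 0., 0.],
                         [0., 1., 1., 1., 0.],
                         [0., 0., 1., 1., 1.],
                         [0., 0., 0., 1., 1.]])
assert torch.equal(circulant_mask(5, 1), expected)   # ones only where |i - j| <= window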
+ + :param config: model config + :type config: ConveRTModelConfig + """ + super().__init__() + + self.config = config + self.query = nn.Linear(config.num_embed_hidden, config.num_attention_project) + self.key = nn.Linear(config.num_embed_hidden, config.num_attention_project) + self.value = nn.Linear(config.num_embed_hidden, config.num_attention_project) + + self.softmax = nn.Softmax(dim=-1) + self.output_projection = nn.Linear( + config.num_attention_project, config.num_embed_hidden + ) + self.bias = torch.nn.Parameter(torch.randn(config.n), requires_grad=True) + stdv = 1.0 / math.sqrt(self.bias.data.size(0)) + self.bias.data.uniform_(-stdv, stdv) + self.relative_attention = relative_attention + self.n = self.config.n + self.half_n = self.n // 2 + self.register_buffer( + "relative_mask", + circulant_mask(config.tokens_len, self.relative_attention), + ) + + def forward( + self, attn_input: torch.Tensor, attention_mask: torch.Tensor + ) -> torch.Tensor: + """calculate self-attention of query, key and weighted to value at the end. + self-attention input is projected by linear layer at the first time. + applying attention mask for ignore pad index attention weight. Relative attention mask + applied and a learnable bias added to the attention scores. + return value after apply output projection layer to value * attention + + :param attn_input: [description] + :type attn_input: [type] + :param attention_mask: [description], defaults to None + :type attention_mask: [type], optional + :return: [description] + :rtype: [type] + """ + self.T = attn_input.size()[1] + # input is B x max seq len x n_emb + _query = self.query.forward(attn_input) + _key = self.key.forward(attn_input) + _value = self.value.forward(attn_input) + + # scaled dot product + attention_scores = torch.matmul(_query, _key.transpose(1, 2)) + attention_scores = attention_scores / math.sqrt( + self.config.num_attention_project + ) + + # Relative attention + + # extended_attention_mask = attention_mask.to(attention_scores.device) # fp16 compatibility + extended_attention_mask = (1.0 - attention_mask.unsqueeze(-1)) * -10000.0 + attention_scores = attention_scores + extended_attention_mask + + # fix circulant_matrix to matrix of size 60 x60 (max token truncation_length, + # register as buffer, so not keep creating masks of different sizes. + + attention_scores = attention_scores.masked_fill( + self.relative_mask.unsqueeze(0)[:, : self.T, : self.T] == 0, float("-inf") + ) + + # Learnable bias vector is used of max size,for each i, different subsets of it are added to the scores, where the permutations + # depend on the relative position (i-j). this way cleverly allows no loops. bias vector is 2*max truncation length+1 + # so has a learnable parameter for each eg. (i-j) /in {-60,...60} . 
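The relative-position bias gather used a few lines below can be sanity-checked with toy sizes (a sketch only; in the model n is tokens_len*2+1 and the vector is the learned parameter self.bias):

import torch

T, n = 4, 9                                       # toy sizes; n // 2 = 4 is the zero-offset slot
bias = torch.arange(n, dtype=torch.float)         # stand-in for the learned bias vector
ii, jj = torch.meshgrid(torch.arange(T), torch.arange(T), indexing="ij")
B = bias[n // 2 - ii + jj]                        # B[i, j] == bias[n//2 + (j - i)]
assert B[0, 0] == 4 and B[0, 3] == 7 and B[3, 0] == 1   # each diagonal holds one offset's bias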
+ + ii, jj = torch.meshgrid(torch.arange(self.T), torch.arange(self.T)) + B_matrix = self.bias[self.n // 2 - ii + jj] + + attention_scores = attention_scores + B_matrix.unsqueeze(0) + + attention_scores = self.softmax(attention_scores) + output = torch.matmul(attention_scores, _value) + + output = self.output_projection(output) + + return [output,attention_scores] # B x T x num embed hidden + + + +class FeedForward1(nn.Module): + def __init__( + self, input_hidden: int, intermediate_hidden: int, dropout_rate: float = 0.0 + ): + # 512 2048 + + super().__init__() + + self.linear_1 = nn.Linear(input_hidden, intermediate_hidden) + self.dropout = nn.Dropout(dropout_rate) + self.linear_2 = nn.Linear(intermediate_hidden, input_hidden) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + x = F.gelu(self.linear_1(x)) + return self.linear_2(self.dropout(x)) + + +class SharedInnerBlock(nn.Module): + def __init__(self, config: DictConfig, relative_attn: int): + super().__init__() + + self.config = config + self.self_attention = SelfAttention(config, relative_attn) + self.norm1 = LayerNorm(config.num_embed_hidden) # 512 + self.dropout = nn.Dropout(config.dropout) + self.ff1 = FeedForward1( + config.num_embed_hidden, config.feed_forward1_hidden, config.dropout + ) + self.norm2 = LayerNorm(config.num_embed_hidden) + + def forward(self, x: torch.Tensor, attention_mask: int) -> torch.Tensor: + + new_values_x,attn_scores = self.self_attention(x, attention_mask=attention_mask) + x = x+new_values_x + x = self.norm1(x) + x = x + self.ff1(x) + return self.norm2(x),attn_scores + + +# pretty basic, just single head. but done many times, stack to have another dimension (4 with batches).# so get stacks of B x H of attention scores T x T.. +# then matrix multiply these extra stacks with the v +# (B xnh)x T xT . (Bx nh xTx hs) gives (B Nh) T x hs stacks. 
now hs is set to be final dimension/ number of heads, so reorder the stacks (concatenating them) +# can have optional extra projection layer, but doing that later + + +class MultiheadAttention(nn.Module): + def __init__(self, config: DictConfig): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.num_attn_proj = config.num_embed_hidden * config.num_attention_heads + self.attention_head_size = int(self.num_attn_proj / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.num_embed_hidden, self.num_attn_proj) + self.key = nn.Linear(config.num_embed_hidden, self.num_attn_proj) + self.value = nn.Linear(config.num_embed_hidden, self.num_attn_proj) + + self.dropout = nn.Dropout(config.dropout) + + def forward( + self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + B, T, _ = hidden_states.size() + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = ( + self.key(hidden_states) + .view(B, T, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # (B, nh, T, hs) + q = ( + self.query(hidden_states) + .view(B, T, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # (B, nh, T, hs) + v = ( + self.value(hidden_states) + .view(B, T, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) # (B, nh, T, hs) + + attention_scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + + if attention_mask is not None: + attention_mask = attention_mask[:, None, None, :] + attention_mask = (1.0 - attention_mask) * -10000.0 + + attention_scores = attention_scores + attention_mask + + attention_scores = F.softmax(attention_scores, dim=-1) + + attention_scores = self.dropout(attention_scores) + + y = attention_scores @ v + + y = y.transpose(1, 2).contiguous().view(B, T, self.num_attn_proj) + + return y + + +class PositionalEncoding(nn.Module): + def __init__(self, model_config: DictConfig,): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=model_config.dropout) + self.num_embed_hidden = model_config.num_embed_hidden + pe = torch.zeros(model_config.tokens_len, self.num_embed_hidden) + position = torch.arange( + 0, model_config.tokens_len, dtype=torch.float + ).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.num_embed_hidden, 2).float() + * (-math.log(10000.0) / self.num_embed_hidden) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x): + x = x + self.pe[: x.size(0), :] + return self.dropout(x) + + +class RNAFFrwd( + nn.Module +): # params are not shared for context and reply. 
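The PositionalEncoding above is the standard sinusoidal table; rebuilding it for toy sizes shows the layout (even columns are sines, odd columns are cosines). Sizes here are illustrative only:

import math
import torch

tokens_len, d = 6, 8
pe = torch.zeros(tokens_len, d)
pos = torch.arange(0, tokens_len, dtype=torch.float).unsqueeze(1)
div = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d))
pe[:, 0::2] = torch.sin(pos * div)                # even feature indices
pe[:, 1::2] = torch.cos(pos * div)                # odd feature indices
assert pe[0, 0] == 0.0 and pe[0, 1] == 1.0        # position 0 -> sin(0), cos(0)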
so need two sets of weights + """Fully-Connected 3-layer Linear Model""" + + def __init__(self, model_config: DictConfig): + """ + :param input_hidden: first-hidden layer input embed-dim + :type input_hidden: int + :param intermediate_hidden: layer-(hidden)-layer middle point weight + :type intermediate_hidden: int + :param dropout_rate: dropout rate, defaults to None + :type dropout_rate: float, optional + """ + # paper specifies,skip connections,layer normalization, and orthogonal initialization + + super().__init__() + # 3,679,744 x2 params + self.rna_ffwd_input_dim = ( + model_config.num_embed_hidden * model_config.num_attention_heads + ) + self.linear_1 = nn.Linear(self.rna_ffwd_input_dim, self.rna_ffwd_input_dim) + self.linear_2 = nn.Linear(self.rna_ffwd_input_dim, self.rna_ffwd_input_dim) + + self.norm1 = LayerNorm(self.rna_ffwd_input_dim) + self.norm2 = LayerNorm(self.rna_ffwd_input_dim) + self.final = nn.Linear(self.rna_ffwd_input_dim, model_config.num_embed_hidden) + self.orthogonal_initialization() # torch implementation works perfectly out the box, + + def orthogonal_initialization(self): + for l in [ + self.linear_1, + self.linear_2, + ]: + torch.nn.init.orthogonal_(l.weight) + + def forward(self, x: torch.Tensor, attn_msk: torch.Tensor) -> torch.Tensor: + sentence_lengths = attn_msk.sum(1) + + # adding square root reduction projection separately as not a shared. + # part of the diagram torch.Size([Batch, scent_len, embedd_dim]) + + # x has dims B x T x 2*d_emb + norms = 1 / torch.sqrt(sentence_lengths.double()).float() # 64 + # TODO: Aggregation is done on all words including the masked ones + x = norms.unsqueeze(1) * torch.sum(x, dim=1) # 64 x1024 + + x = x + F.gelu(self.linear_1(self.norm1(x))) + x = x + F.gelu(self.linear_2(self.norm2(x))) + + return F.normalize(self.final(x), dim=1, p=2) # 64 512 + + +class RNATransformer(nn.Module): + def __init__(self, model_config: DictConfig): + super().__init__() + self.num_embedd_hidden = model_config.num_embed_hidden + self.encoder = nn.Embedding( + model_config.vocab_size, model_config.num_embed_hidden + ) + self.model_input = model_config.model_input + if 'baseline' not in self.model_input: + # positional encoder + self.pos_encoder = PositionalEncoding(model_config) + + self.transformer_layers = nn.ModuleList( + [ + SharedInnerBlock(model_config, int(window/model_config.window)) + for window in model_config.relative_attns[ + : model_config.num_encoder_layers + ] + ] + ) + self.MHA = MultiheadAttention(model_config) + # self.concatenate = FeedForward2(model_config) + + self.rna_ffrwd = RNAFFrwd(model_config) + self.pad_id = 0 + + def forward(self, x:torch.Tensor) -> torch.Tensor: + if x.is_cuda: + long_tensor = torch.cuda.LongTensor + else: + long_tensor = torch.LongTensor + + embedds = self.encoder(x) + if 'baseline' not in self.model_input: + output = self.pos_encoder(embedds) + attention_mask = (x != self.pad_id).int() + + for l in self.transformer_layers: + output,attn_scores = l(output, attention_mask) + output = self.MHA(output) + output = self.rna_ffrwd(output, attention_mask) + return output,attn_scores + else: + embedds = torch.flatten(embedds,start_dim=1) + return embedds,None + +class GeneEmbeddModel(nn.Module): + def __init__( + self, main_config: DictConfig, + ): + super().__init__() + self.train_config = main_config["train_config"] + self.model_config = main_config["model_config"] + self.device = self.train_config.device + self.model_input = self.model_config["model_input"] + self.false_input_perc = 
self.model_config["false_input_perc"] + #adjust n (used to add rel bias on attn scores) + self.model_config.n = self.model_config.tokens_len*2+1 + self.transformer_layers = RNATransformer(self.model_config) + #save tokens_len of sequences to be used to split ids between transformers + self.tokens_len = self.model_config.tokens_len + #reassign tokens_len and vocab_size to init a new transformer + #more clean solution -> RNATransformer and its children should + # have a flag input indicating which transformer + self.model_config.tokens_len = self.model_config.second_input_token_len + self.model_config.n = self.model_config.tokens_len*2+1 + self.seq_vocab_size = self.model_config.vocab_size + #this differs between both models not the token_len/ss_token_len + self.model_config.vocab_size = self.model_config.second_input_vocab_size + + self.second_input_model = RNATransformer(self.model_config) + + #num_transformers refers to using either one model or two in parallel + self.num_transformers = 2 + if self.model_input == 'seq': + self.num_transformers = 1 + # could be moved to model + self.weight_decay = self.train_config.l2_weight_decay + if 'baseline' in self.model_input: + self.num_transformers = 1 + num_nodes = self.model_config.num_embed_hidden*self.tokens_len + self.final_clf_1 = nn.Linear(num_nodes,self.model_config.num_classes) + else: + #setting classification layer + num_nodes = self.num_transformers*self.model_config.num_embed_hidden + if self.num_transformers == 1: + self.final_clf_1 = nn.Linear(num_nodes,self.model_config.num_classes) + else: + self.final_clf_1 = nn.Linear(num_nodes,num_nodes) + self.final_clf_2 = nn.Linear(num_nodes,self.model_config.num_classes) + self.relu = nn.ReLU() + self.BN = nn.BatchNorm1d(num_nodes) + self.dropout = nn.Dropout(0.6) + + logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + + + def distort_input(self,x): + for sample_idx in range(x.shape[0]): + seq_length = x[sample_idx,-1] + num_tokens_flipped = int(self.false_input_perc*seq_length) + max_start_flip_idx = seq_length - num_tokens_flipped + + random_feat_idx = random.randint(0,max_start_flip_idx-1) + x[sample_idx,random_feat_idx:random_feat_idx+num_tokens_flipped] = \ + torch.tensor(np.random.choice(range(1,self.seq_vocab_size-1),size=num_tokens_flipped,replace=True)) + + x[sample_idx,random_feat_idx+self.tokens_len:random_feat_idx+self.tokens_len+num_tokens_flipped] = \ + torch.tensor(np.random.choice(range(1,self.model_config.second_input_vocab_size-1),size=num_tokens_flipped,replace=True)) + return x + + def forward(self, x,train=False): + if self.device == 'cuda': + long_tensor = torch.cuda.LongTensor + float_tensor = torch.cuda.FloatTensor + else: + long_tensor = torch.LongTensor + float_tensor = torch.FloatTensor + if train: + if self.false_input_perc > 0: + x = self.distort_input(x) + + gene_embedd,attn_scores_first = self.transformer_layers( + x[:, : self.tokens_len].type(long_tensor) + ) + attn_scores_second = None + second_input_embedd,attn_scores_second = self.second_input_model( + x[:, self.tokens_len :-1].type(long_tensor) + ) + + #for tcga: if seq or baseline + if self.num_transformers == 1: + activations = self.final_clf_1(gene_embedd) + else: + out_clf_1 = self.final_clf_1(torch.cat((gene_embedd, second_input_embedd), 1)) + out = self.BN(out_clf_1) + out = self.relu(out) + out = self.dropout(out) + activations = self.final_clf_2(out) + + #create dummy attn scores for baseline + if 'baseline' in self.model_input: + attn_scores_first = 
torch.ones((1,2,2),device=x.device) + + return [gene_embedd, second_input_embedd, activations,attn_scores_first,attn_scores_second] diff --git a/transforna/src/model/skorchWrapper.py b/transforna/src/model/skorchWrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..126249945cfd5ce0a63aa6329d9c7f59e1b3cb63 --- /dev/null +++ b/transforna/src/model/skorchWrapper.py @@ -0,0 +1,364 @@ +import logging +import os +import pickle + +import skorch +import torch +from skorch.dataset import Dataset, ValidSplit +from skorch.setter import optimizer_setter +from skorch.utils import is_dataset, to_device + +logger = logging.getLogger(__name__) +#from ..tbWriter import writer + + +class Net(skorch.NeuralNet): + def __init__( + self, + clip=0.25, + top_k=1, + correct=0, + save_embedding=False, + gene_embedds=[], + second_input_embedd=[], + confidence_threshold = 0.95, + *args, + **kwargs + ): + self.clip = clip + self.curr_epoch = 0 + super(Net, self).__init__(*args, **kwargs) + self.correct = correct + self.save_embedding = save_embedding + self.gene_embedds = gene_embedds + self.second_input_embedds = second_input_embedd + self.main_config = kwargs["module__main_config"] + self.train_config = self.main_config["train_config"] + self.top_k = self.train_config.top_k + self.num_classes = self.main_config["model_config"].num_classes + self.labels_mapping_path = self.train_config.labels_mapping_path + if self.labels_mapping_path: + with open(self.labels_mapping_path, 'rb') as handle: + self.labels_mapping_dict = pickle.load(handle) + self.confidence_threshold = confidence_threshold + self.max_epochs = kwargs["max_epochs"] + self.task = '' #is set in utils.instantiate_predictor + self.log_tb = False + + + + + def set_save_epoch(self): + ''' + scale best train epoch by valid size + ''' + if self.task !='tcga': + if self.train_split: + self.save_epoch = self.main_config["train_config"].train_epoch + else: + self.save_epoch = round(self.main_config["train_config"].train_epoch*\ + (1+self.main_config["valid_size"])) + + def save_benchmark_model(self): + ''' + saves benchmark epochs when train_split is none + ''' + try: + os.mkdir("ckpt") + except: + pass + cwd = os.getcwd()+"/ckpt/" + self.save_params(f_params= f'{cwd}/model_params_{self.main_config["task"]}.pt') + + + def fit(self, X, y=None, valid_ds=None,**fit_params): + #all sequence lengths should be saved to compute the median based + self.all_lengths = [[] for i in range(self.num_classes)] + self.median_lengths = [] + + if not self.warm_start or not self.initialized_: + self.initialize() + + if valid_ds: + self.validation_dataset = valid_ds + else: + self.validation_dataset = None + + self.partial_fit(X, y, **fit_params) + return self + + def fit_loop(self, X, y=None, epochs=None, **fit_params): + #if id then train longer otherwise stop at 0.99 + rounding_digits = 3 + if self.main_config['trained_on'] == 'full': + rounding_digits = 2 + self.check_data(X, y) + epochs = epochs if epochs is not None else self.max_epochs + + dataset_train, dataset_valid = self.get_split_datasets(X, y, **fit_params) + + if self.validation_dataset is not None: + dataset_valid = self.validation_dataset.keywords["valid_ds"] + + on_epoch_kwargs = { + "dataset_train": dataset_train, + "dataset_valid": dataset_valid, + } + + iterator_train = self.get_iterator(dataset_train, training=True) + iterator_valid = None + if dataset_valid is not None: + iterator_valid = self.get_iterator(dataset_valid, training=False) + + self.set_save_epoch() + + for epoch_no in 
range(epochs): + #save model if training only on test set + self.curr_epoch = epoch_no + #save epoch is scaled by best train epoch + #save benchmark only when training on boith train and val sets + if self.task != 'tcga' and epoch_no == self.save_epoch and self.train_split == None: + self.save_benchmark_model() + + self.notify("on_epoch_begin", **on_epoch_kwargs) + + self.run_single_epoch( + iterator_train, + training=True, + prefix="train", + step_fn=self.train_step, + **fit_params + ) + + if dataset_valid is not None: + self.run_single_epoch( + iterator_valid, + training=False, + prefix="valid", + step_fn=self.validation_step, + **fit_params + ) + + + self.notify("on_epoch_end", **on_epoch_kwargs) + #manual early stopping for tcga + if self.task == 'tcga': + train_acc = round(self.history[:,'train_acc'][-1],rounding_digits) + if train_acc == 1: + break + + + + return self + + def train_step(self, X, y=None): + y = X[1] + X = X[0] + sample_weights = X[:,-1] + if self.device == 'cuda': + sample_weights = sample_weights.to(self.train_config.device) + self.module_.train() + self.module_.zero_grad() + gene_embedd, second_input_embedd, activations,_,_ = self.module_(X[:,:-1],train=True) + #curr_epoch is passed to loss as it is used to switch loss criteria from unsup. -> sup + loss = self.get_loss([gene_embedd,second_input_embedd,activations,self.curr_epoch], y) + + ###sup loss should be X with samples weight and aggregated + + loss = loss*sample_weights + loss = loss.mean() + + loss.backward() + + # TODO: clip only some parameters + torch.nn.utils.clip_grad_norm_(self.module_.parameters(), self.clip) + self.optimizer_.step() + + return {"X":X,"y":y,"loss": loss, "y_pred": [gene_embedd,second_input_embedd,activations]} + + def validation_step(self, X, y=None): + y = X[1] + X = X[0] + sample_weights = X[:,-1] + if self.device == 'cuda': + sample_weights = sample_weights.to(self.train_config.device) + self.module_.eval() + with torch.no_grad(): + gene_embedd, second_input_embedd, activations,_,_ = self.module_(X[:,:-1]) + loss = self.get_loss([gene_embedd,second_input_embedd,activations,self.curr_epoch], y) + + ###sup loss should be X with samples weight and aggregated + + loss = loss*sample_weights + loss = loss.mean() + + return {"X":X,"y":y,"loss": loss, "y_pred": [gene_embedd,second_input_embedd,activations]} + + def get_attention_scores(self, X, y=None): + ''' + returns attention scores for a given input + ''' + self.module_.eval() + with torch.no_grad(): + _, _, _,attn_scores_first,attn_scores_second = self.module_(X[:,:-1]) + + attn_scores_first = attn_scores_first.detach().cpu().numpy() + if attn_scores_second is not None: + attn_scores_second = attn_scores_second.detach().cpu().numpy() + return attn_scores_first,attn_scores_second + + def predict(self, X): + self.module_.train(False) + embedds = self.module_(X[:,:-1]) + sample_weights = X[:,-1] + if self.device == 'cuda': + sample_weights = sample_weights.to(self.train_config.device) + + gene_embedd, second_input_embedd, activations,_,_ = embedds + if self.save_embedding: + self.gene_embedds.append(gene_embedd.detach().cpu()) + #in case only a single transformer is deployed, then second_input_embedd are None. 
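Because the classification criterion in criterion.py is built with reduction='none', train_step above can weight each sample's loss by the trailing sample-weight column before averaging. A minimal standalone sketch with invented tensors:

import torch
import torch.nn as nn

logits = torch.randn(4, 3, requires_grad=True)     # toy batch of 4, 3 classes
targets = torch.tensor([0, 2, 1, 2])
weights = torch.tensor([1.0, 0.5, 2.0, 1.0])       # e.g. the trailing sample-weight column

per_sample = nn.CrossEntropyLoss(reduction="none")(logits, targets)   # shape (4,)
loss = (per_sample * weights).mean()               # weighted aggregation as in train_step
loss.backward()                                    # gradients flow through the weighted mean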
thus have no detach() + if second_input_embedd is not None: + self.second_input_embedds.append(second_input_embedd.detach().cpu()) + + predictions = torch.cat([activations,sample_weights[:,None]],dim=1) + return predictions + + + def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs): + # log gradients and weights + for _, m in self.module_.named_modules(): + for pn, p in m.named_parameters(): + if pn.endswith("weight") and pn.find("norm") < 0: + if p.grad != None: + if self.log_tb: + from ..callbacks.tbWriter import writer + writer.add_histogram("weights/" + pn, p, len(net.history)) + writer.add_histogram( + "gradients/" + pn, p.grad.data, len(net.history) + ) + + return + + def configure_opt(self, l2_weight_decay): + no_decay = ["bias", "LayerNorm.weight"] + params_decay = [ + p + for n, p in self.module_.named_parameters() + if not any(nd in n for nd in no_decay) + ] + params_nodecay = [ + p + for n, p in self.module_.named_parameters() + if any(nd in n for nd in no_decay) + ] + optim_groups = [ + {"params": params_decay, "weight_decay": l2_weight_decay}, + {"params": params_nodecay, "weight_decay": 0.0}, + ] + return optim_groups + + def initialize_optimizer(self, triggered_directly=True): + """Initialize the model optimizer. If ``self.optimizer__lr`` + is not set, use ``self.lr`` instead. + + Parameters + ---------- + triggered_directly : bool (default=True) + Only relevant when optimizer is re-initialized. + Initialization of the optimizer can be triggered directly + (e.g. when lr was changed) or indirectly (e.g. when the + module was re-initialized). If and only if the former + happens, the user should receive a message informing them + about the parameters that caused the re-initialization. + + """ + # get learning rate from train config + optimizer_params = self.main_config["train_config"] + kwargs = {} + kwargs["lr"] = optimizer_params.learning_rate + # get l2 weight decay to init opt params + args = self.configure_opt(optimizer_params.l2_weight_decay) + + if self.initialized_ and self.verbose: + msg = self._format_reinit_msg( + "optimizer", kwargs, triggered_directly=triggered_directly + ) + logger.info(msg) + + self.optimizer_ = self.optimizer(args, lr=kwargs["lr"]) + + self._register_virtual_param( + ["optimizer__param_groups__*__*", "optimizer__*", "lr"], + optimizer_setter, + ) + + def initialize_criterion(self): + """Initializes the criterion.""" + # critereon takes train_config and model_config as an input. + # we get both from the module parameters + self.criterion_ = self.criterion( + self.main_config + ) + if isinstance(self.criterion_, torch.nn.Module): + self.criterion_ = to_device(self.criterion_, self.device) + return self + + def initialize_callbacks(self): + """Initializes all callbacks and save the result in the + ``callbacks_`` attribute. + + Both ``default_callbacks`` and ``callbacks`` are used (in that + order). Callbacks may either be initialized or not, and if + they don't have a name, the name is inferred from the class + name. The ``initialize`` method is called on all callbacks. + + The final result will be a list of tuples, where each tuple + consists of a name and an initialized callback. If names are + not unique, a ValueError is raised. + + """ + if self.callbacks == "disable": + self.callbacks_ = [] + return self + + callbacks_ = [] + + class Dummy: + # We cannot use None as dummy value since None is a + # legitimate value to be set. 
+ pass + + for name, cb in self._uniquely_named_callbacks(): + # check if callback itself is changed + param_callback = getattr(self, "callbacks__" + name, Dummy) + if param_callback is not Dummy: # callback itself was set + cb = param_callback + + # below: check for callback params + # don't set a parameter for non-existing callback + + # if the callback is lrcallback then initializa it with the train config, + # which is an input to the module + if name == "lrcallback": + params["config"] = self.main_config["train_config"] + else: + params = self.get_params_for("callbacks__{}".format(name)) + if (cb is None) and params: + raise ValueError( + "Trying to set a parameter for callback {} " + "which does not exist.".format(name) + ) + if cb is None: + continue + + if isinstance(cb, type): # uninitialized: + cb = cb(**params) + else: + cb.set_params(**params) + cb.initialize() + callbacks_.append((name, cb)) + + self.callbacks_ = callbacks_ + + return self diff --git a/transforna/src/novelty_prediction/__init__.py b/transforna/src/novelty_prediction/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..78f988b58d417903b6ef7db1af8a95b69ab92c09 --- /dev/null +++ b/transforna/src/novelty_prediction/__init__.py @@ -0,0 +1,2 @@ +from .id_vs_ood_entropy_clf import * +from .id_vs_ood_nld_clf import * diff --git a/transforna/src/novelty_prediction/id_vs_ood_entropy_clf.py b/transforna/src/novelty_prediction/id_vs_ood_entropy_clf.py new file mode 100644 index 0000000000000000000000000000000000000000..58239e68ae1a351c327d8c757d1c8bb8fe92345d --- /dev/null +++ b/transforna/src/novelty_prediction/id_vs_ood_entropy_clf.py @@ -0,0 +1,203 @@ + +#%% +#A script for classifying OOD vs HICO ID (test split). Generates results depicted in figure 4c +import json +import logging + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from imblearn.under_sampling import RandomUnderSampler +from scipy.stats import entropy +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split + +from ..utils.file import save +from ..utils.tcga_post_analysis_utils import Results_Handler +from .utlis import compute_prc, compute_roc + +logger = logging.getLogger(__name__) + +def entropy_clf(results,random_state:int=1): + #clf entropy is test vs ood if sub_class else vs loco_not_in_train + art_affix_ent = results.splits_df_dict["artificial_affix_df"]["Entropy"]["0"].values + + if results.trained_on == 'id': + test_ent = results.splits_df_dict["test_df"]["Entropy"]["0"].values + else: + test_ent = results.splits_df_dict["train_df"]["Entropy"]["0"].values[:int(0.25*len(results.splits_df_dict["train_df"]))] + ent_x = np.concatenate((art_affix_ent,test_ent)) + ent_labels = np.concatenate((np.zeros(art_affix_ent.shape),np.ones(test_ent.shape))) + trainX, testX, trainy, testy = train_test_split(ent_x, ent_labels, stratify=ent_labels,test_size=0.9, random_state=random_state) + + model = LogisticRegression(solver='lbfgs',class_weight='balanced') + + model.fit(trainX.reshape(-1, 1), trainy) + #balance testset + undersample = RandomUnderSampler(sampling_strategy='majority') + testX,testy = undersample.fit_resample(testX.reshape(-1,1),testy) + + # predict probabilities + lr_probs = model.predict_proba(testX) + # keep probabilities for the positive outcome only + lr_probs = lr_probs[:, 1] + yhat = model.predict(testX) + return testy,lr_probs,yhat,model + +def plot_entropy(results): + + entropies_list = [] + test_idx = 0 + for split_idx,split in 
enumerate(results.splits): + entropies_list.append(results.splits_df_dict[f"{split}_df"]["Entropy"]["0"].values) + if split == 'test': + test_idx = split_idx + + + bx = plt.boxplot(entropies_list) + plt.title("Entropy Distribution") + plt.ylabel("Entropy") + plt.xticks(np.arange(len(results.splits))+1,results.splits) + if results.save_results: + plt.savefig(f"{results.figures_path}/entropy_test_vs_ood_boxplot.png") + plt.savefig(f"{results.figures_path}/entropy_test_vs_ood_boxplot.svg") + plt.xticks(rotation=45) + plt.show() + return [item.get_ydata()[1] for item in bx['whiskers']][2*test_idx+1] + +def plot_entropy_per_unique_length(results,split): + seqs_len = results.splits_df_dict[f"{split}_df"]["RNA Sequences",'0'].str.len().values + index = results.splits_df_dict[f"{split}_df"]["RNA Sequences",'0'].values + entropies = results.splits_df_dict[f"{split}_df"]["Entropy","0"].values + + #create df for plotting + df = pd.DataFrame({"Entropy":entropies,"Sequences Length":seqs_len},index=index) + + + fig = df.boxplot(by='Sequences Length') + fig.get_figure().gca().set_title("") + fig.get_figure().gca().set_xlabel(f"Sequences Length ({split})") + fig.get_figure().gca().set_ylabel("Entropy") + plt.show() + if results.save_results: + plt.savefig(f"{results.figures_path}/{split}_entropy_per_length_boxplot.png") + +def plot_outliers(results,test_whisker_UB): + test_df = results.splits_df_dict["test_df"] + test_ent = test_df["Entropy", "0"] + #decompose outliers in ID + outlier_seqs = test_df.iloc[(test_whisker_UB < test_ent).values]['RNA Sequences']['0'].values + outlier_seqs_in_ad = list(set(results.dataset).intersection(set(outlier_seqs))) + major_class_dict = results.dataset.loc[outlier_seqs_in_ad]['small_RNA_class_annotation'][~results.dataset['hico'].isnull()].value_counts() + major_class_dict = {x:y for x,y in major_class_dict.items() if y!=0} + plt.pie(major_class_dict.values(),labels=major_class_dict.keys(),autopct='%1.1f%%') + plt.axis('equal') + plt.show() + plt.savefig(f"{results.figures_path}/Decomposition_outliers_in_test_pie.png") + #log outlier seqs to meta + if results.save_results: + save(data = outlier_seqs.tolist(),path=results.analysis_path+"/logit_outlier_seqs_ID.yaml") + +def log_model_params(model,analysis_path): + model_params = {"Model Coef": model.coef_[0][0],\ + "Model intercept": model.intercept_[0],\ + "Threshold": -model.intercept_[0]/model.coef_[0][0]} + model_params = eval(json.dumps(model_params)) + save(data = model_params,path=analysis_path+"/logits_model_coef.yaml") + + model.threshold = model_params["Threshold"] + +def compute_entropy_per_split(results:Results_Handler): + #compute entropy per split + for split in results.splits: + results.splits_df_dict[f"{split}_df"]["Entropy","0"] = entropy(results.splits_df_dict[f"{split}_df"]["Logits"].values,axis=1) + +def compute_novelty_prediction_per_split(results,model): + #add noovelty prediction for all splits + for split in results.splits: + results.splits_df_dict[f'{split}_df']['Novelty Prediction','is_known_class'] = results.splits_df_dict[f'{split}_df']['Entropy','0']<= model.threshold + +def compute_logits_clf_metrics(results): + aucs_roc = [] + aucs_prc = [] + f1s_prc = [] + replicates = 10 + show_figure: bool = False + for random_state in range(replicates): + #plot only for the last random seed + if random_state == replicates-1: + show_figure = True + #classify ID from OOD using entropy + test_labels,lr_probs,yhat,model = entropy_clf(results,random_state) + ###logs + if results.save_results: + 
log_model_params(model,results.analysis_path) + compute_novelty_prediction_per_split(results,model) + ###plots + auc_roc = compute_roc(test_labels,lr_probs,results,show_figure) + f1_prc,auc_prc = compute_prc(test_labels,lr_probs,yhat,results,show_figure) + aucs_roc.append(auc_roc) + aucs_prc.append(auc_prc) + f1s_prc.append(f1_prc) + + + auc_roc_score = sum(aucs_roc)/len(aucs_roc) + auc_roc_std = np.std(aucs_roc) + auc_prc_score = sum(aucs_prc)/len(aucs_prc) + auc_prc_std = np.std(aucs_prc) + f1_prc_score = sum(f1s_prc)/len(f1s_prc) + f1_prc_std = np.std(f1s_prc) + + logger.info(f"auc roc is {auc_roc_score} +- {auc_roc_std}") + logger.info(f"auc prc is {auc_prc_score} +- {auc_prc_std}") + logger.info(f"f1 prc is {f1_prc_score} +- {f1_prc_std}") + + logits_clf_metrics = {"AUC ROC score": auc_roc_score,\ + "auc_roc_std": auc_roc_std,\ + "AUC PRC score": auc_prc_score,\ + "auc_prc_std":auc_prc_std,\ + "F1 PRC score": f1_prc_score,\ + "f1_prc_std":f1_prc_std + } + + logits_clf_metrics = eval(json.dumps(logits_clf_metrics)) + if results.save_results: + save(data = logits_clf_metrics,path=results.analysis_path+"/logits_clf_metrics.yaml") + +def compute_entropies(embedds_path): + logger.info("Computing entropy for ID vs OOD:") + #######################################TO CONFIGURE############################################# + #embedds_path = ''#f'models/tcga/TransfoRNA_{trained_on.upper()}/sub_class/{model}/embedds' #edit path to contain path for the embedds folder, for example: transforna/results/seq-rev/embedds/ + splits = ['train','valid','test','ood','artificial','no_annotation'] + #run name + run_name = None #if None, then the name of the model inputs will be used as the name + #this could be for instance 'Sup Seq-Exp' + ################################################################################################ + results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=splits,read_dataset=True,save_results=True) + + results.append_loco_variants() + + results.splits[-1:-1] = ['artificial_affix','loco_not_in_train','loco_mixed','loco_in_train'] + + + if results.mc_flag: + results.splits.remove("ood") + + compute_entropy_per_split(results) + #remove train and valid from plotting entropy due to clutter + results.splits.remove("train") + results.splits.remove("valid") + + compute_logits_clf_metrics(results) + + test_whisker_UB = plot_entropy(results) + logger.info("plotting entropy per unique length") + plot_entropy_per_unique_length(results,'artificial_affix') + logger.info('plotting entropy per unique length for ood') + #decompose outliers in ID + logger.info("plotting outliers") + plot_outliers(results,test_whisker_UB) + + + +# %% diff --git a/transforna/src/novelty_prediction/id_vs_ood_nld_clf.py b/transforna/src/novelty_prediction/id_vs_ood_nld_clf.py new file mode 100644 index 0000000000000000000000000000000000000000..7dee43ecc6564f1f8afc1ddb0bef5b3bc7f88f30 --- /dev/null +++ b/transforna/src/novelty_prediction/id_vs_ood_nld_clf.py @@ -0,0 +1,222 @@ + +# %% +#A script for classifying OOD vs HICO ID (test split). 
Generates results depicted in figure 4c + +import json +import logging +from typing import List + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import plotly.express as px +from imblearn.under_sampling import RandomUnderSampler +from Levenshtein import distance +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split + +from ..utils.file import load, save +from ..utils.tcga_post_analysis_utils import Results_Handler +from .utlis import compute_prc, compute_roc + +logger = logging.getLogger(__name__) + +def get_lev_dist(seqs_a_list:List[str],seqs_b_list:List[str]): + ''' + compute levenstein distance between two lists of sequences and normalize by the length of the longest sequence + The lev distance is computed between seqs_a_list[i] and seqs_b_list[i] + ''' + lev_dist = [] + for i in range(len(seqs_a_list)): + dist = distance(seqs_a_list[i],seqs_b_list[i]) + #normalize + dist = dist/max(len(seqs_a_list[i]),len(seqs_b_list[i])) + lev_dist.append(dist) + return lev_dist + +def get_closest_neighbors(results:Results_Handler,query_embedds:np.ndarray,num_neighbors:int=1): + ''' + get the closest neighbors to the query embedds using the knn model in results + The closest neighbors are to be found in the training set + ''' + #norm infer embedds + query_embedds = query_embedds/np.linalg.norm(query_embedds,axis=1)[:,None] + #get top 1 seqs + distances, indices = results.knn_model.kneighbors(query_embedds) + distances = distances[:,:num_neighbors].flatten() + #flatten distances + + indices = indices[:,:num_neighbors] + + top_n_seqs = np.array(results.knn_seqs)[indices][:,:num_neighbors] + top_n_seqs = [seq[0] for sublist in top_n_seqs for seq in sublist] + top_n_labels = np.array(results.knn_labels)[indices][:,:num_neighbors] + top_n_labels = [label[0] for sublist in top_n_labels for label in sublist] + + return top_n_seqs,top_n_labels,distances + +def get_closest_ngbr_per_split(results:Results_Handler,split:str,num_neighbors:int=1): + ''' + compute levenstein distance between the sequences in split and their closest neighbors in the training set + ''' + split_df = results.splits_df_dict[f'{split}_df'] + #log + logger.debug(f'number of sequences in {split} is {split_df.shape[0]}') + #accomodate for multi-index df or single index + try: + split_seqs = split_df[results.seq_col].values[:,0] + except: + split_seqs = split_df[results.seq_col].values + try: + split_labels = split_df[results.label_col].values[:,0] + except: + split_labels = None + #get embedds + embedds = split_df[results.embedds_cols].values + + top_n_seqs,top_n_labels,distances = get_closest_neighbors(results,embedds,num_neighbors) + #get levenstein distance + #for each split_seqs duplicate it num_neighbors times + split_seqs = [seq for seq in split_seqs for _ in range(num_neighbors)] + lev_dist = get_lev_dist(split_seqs,top_n_seqs) + return split_seqs,split_labels,top_n_seqs,top_n_labels,distances,lev_dist + + +def log_lev_params(threshold:float,analysis_path:str): + model_params = {"Threshold": threshold} + model_params = eval(json.dumps(model_params)) + save(data = model_params,path=analysis_path+"/novelty_model_coef.yaml") + +def lev_clf(set_a,set_b,random_state): + #get labels + y = np.concatenate((np.zeros(len(set_a)),np.ones(len(set_b)))) + #get levenstein distance + lev_dist = np.concatenate((set_a,set_b)) + #upsample minority class + oversample = RandomUnderSampler(sampling_strategy='majority',random_state=random_state) + lev_dist, y = 
oversample.fit_resample(lev_dist.reshape(-1,1), y) + #get levenstein distance as a feature + lev_dist = lev_dist.reshape(-1,1) + #split to train and test + X_train, X_test, y_train, y_test = train_test_split(lev_dist, y, test_size=0.33, random_state=random_state) + #define model + model = LogisticRegression(solver='lbfgs') + #fit model + model.fit(X_train, y_train) + #predict probabilities + lr_probs = model.predict_proba(X_test)[:, 1] + #predict class + yhat = model.predict(X_test) + return y_test,lr_probs,yhat,model + + +def compute_novelty_clf_metrics(results:Results_Handler,lev_dist_id_set,lev_dist_ood_set): + aucs_roc = [] + aucs_prc = [] + f1s_prc = [] + thresholds = [] + replicates = 10 + show_figure: bool = False + + for random_state in range(replicates): + #plot only for the last random seed + if random_state == replicates-1: + show_figure = True + #classify ID from OOD using entropy + test_labels,lr_probs,yhat,model = lev_clf(lev_dist_id_set,lev_dist_ood_set,random_state) + thresholds.append(-model.intercept_[0]/model.coef_[0][0]) + mean_thresh = sum(thresholds)/len(thresholds) + ###logs + if results.save_results: + log_lev_params(mean_thresh,results.analysis_path) + ###plots + auc_roc = compute_roc(test_labels,lr_probs,results,show_figure) + f1_prc,auc_prc = compute_prc(test_labels,lr_probs,yhat,results,show_figure) + aucs_roc.append(auc_roc) + aucs_prc.append(auc_prc) + f1s_prc.append(f1_prc) + + + auc_roc_score = sum(aucs_roc)/len(aucs_roc) + auc_roc_std = np.std(aucs_roc) + auc_prc_score = sum(aucs_prc)/len(aucs_prc) + auc_prc_std = np.std(aucs_prc) + f1_prc_score = sum(f1s_prc)/len(f1s_prc) + f1_prc_std = np.std(f1s_prc) + + logger.info(f"auc roc is {auc_roc_score} +- {auc_roc_std}") + logger.info(f"auc prc is {auc_prc_score} +- {auc_prc_std}") + logger.info(f"f1 prc is {f1_prc_score} +- {f1_prc_std}") + + novelty_clf_metrics = {"AUC ROC score": auc_roc_score,\ + "auc_roc_std": auc_roc_std,\ + "AUC PRC score": auc_prc_score,\ + "auc_prc_std":auc_prc_std,\ + "F1 PRC score": f1_prc_score,\ + "f1_prc_std":f1_prc_std + } + + novelty_clf_metrics = eval(json.dumps(novelty_clf_metrics)) + if results.save_results: + save(data = novelty_clf_metrics,path=results.analysis_path+"/novelty_clf_metrics.yaml") + + return sum(thresholds)/len(thresholds) + + +def compute_nlds(embedds_path): + logger.info("Computing NLD metrics") + #######################################TO CONFIGURE############################################# + logger.info("Computing novelty clf metrics") + #embedds_path = ''#f'models/tcga/TransfoRNA_{trained_on.upper()}/sub_class/{model}/embedds' #edit path to contain path for the embedds folder, for example: transforna/results/seq-rev/embedds/ + splits = ['train','valid','test','ood','artificial','no_annotation'] + #run name + run_name = None #if None, then the name of the model inputs will be used as the name + #this could be for instance 'Sup Seq-Exp' + ################################################################################################ + results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=splits,read_dataset=True,create_knn_graph=True,save_results=True) + results.append_loco_variants() + #get knn model + results.get_knn_model() + lev_dist_df = pd.DataFrame() + + + #compute levenstein distance per split + for split in results.splits_df_dict.keys(): + if len(results.splits_df_dict[f'{split}']) == 0: + continue + split_seqs,split_labels,top_n_seqs,top_n_labels,distances,lev_dist = 
get_closest_ngbr_per_split(results,'_'.join(split.split('_')[:-1])) + #create df from split and Levenshtein distance + lev_dist_split_df = pd.DataFrame({'split':split,'lev_dist':lev_dist,'seqs':split_seqs,'labels':split_labels,'top_n_seqs':top_n_seqs,'top_n_labels':top_n_labels}) + #append to lev_dist_df + lev_dist_df = lev_dist_df.append(lev_dist_split_df) + + #plot boxplot of Levenshtein distance per split using plotly and add seqs and labels + fig = px.box(lev_dist_df, x="split", y="lev_dist",points="all",hover_data=['seqs','labels','top_n_seqs','top_n_labels']) + #reduce marker size + fig.update_traces(marker=dict(size=2)) + fig.show() + #save as html file in figures_path + fig.write_html(f'{results.figures_path}/lev_distance_distribution.html') + fig.write_image(f'{results.figures_path}/lev_distance_distribution.png') + + #get rows of lev_dist_df from ood/artificial_affix split and from test split + if 'ood_df' in lev_dist_df['split'].values: #for ID models + novel_df = lev_dist_df[lev_dist_df['split'] == 'ood_df'] + else:#for FULL models as all classes are used for training: no OOD + novel_df = lev_dist_df[lev_dist_df['split'] == 'artificial_affix_df'] + test_df = lev_dist_df[lev_dist_df['split'] == 'test_df'] + + lev_dist_df.to_csv(f'{results.analysis_path}/lev_dist_df.csv') + + #compute novelty clf metrics + compute_novelty_clf_metrics(results,test_df['lev_dist'].values,novel_df['lev_dist'].values) + + + + + + + + + + diff --git a/transforna/src/novelty_prediction/readme.md b/transforna/src/novelty_prediction/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..b189a8f18262be2652934bb03f9db9f9a3c0cdac --- /dev/null +++ b/transforna/src/novelty_prediction/readme.md @@ -0,0 +1,33 @@ +## Folder Structure +- The tcga scripts contain the novelty prediction: +- There are two approaches to novelty prediction: NLD and entropy. + - `id_vs_ood_entropy_clf.py`: learns a threshold discerning Familiar vs Novel sequences based on entropy. + - `id_vs_ood_nld_clf.py`: learns a threshold discerning Familiar vs Novel sequences based on the normalized Levenshtein distance (NLD). + + + +## Results Handler +Results Handler is a class that manages a given run. For instance, it can selectively read specific sequence splits within the run directory, the AnnData object the model was trained on and the configurations used; it also handles the creation of the KNN graph on which the novelty prediction is based and the computation of the UMAPs. + +To use `Results_Handler`, make sure to edit the `path` in the `main` function. This path should point to the `embedds` folder of a given run. For example, change: + +``` +path = None +``` + +to + +``` +path = /path/to/embedds/ +``` + +Results will be generated but not saved. If the results should be saved, change + +``` +results:Results_Handler = Results_Handler(path=path,splits=splits) +``` +to +``` +results:Results_Handler = Results_Handler(path=path,splits=splits,save_results=True) +``` +This will save the results at the same level as the `embedds` folder of a given run. 
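At inference time the NLD-based approach boils down to a simple decision rule: compute the normalized Levenshtein distance between a query sequence and its closest neighbor in the training set, then compare it to the threshold learned by `compute_novelty_clf_metrics` (derived from the intercept and coefficient of the fitted logistic regression and stored in `novelty_model_coef.yaml`). The snippet below is a minimal, self-contained sketch of that rule, not part of the package: the sequences and the threshold value are made up for illustration, and only the normalization mirrors `get_lev_dist` in `id_vs_ood_nld_clf.py`.

```
from Levenshtein import distance


def normalized_lev_dist(seq_a: str, seq_b: str) -> float:
    # Levenshtein distance normalized by the length of the longer sequence,
    # as done in get_lev_dist
    return distance(seq_a, seq_b) / max(len(seq_a), len(seq_b))


# hypothetical query and its closest training-set neighbor (e.g. retrieved via the KNN graph)
query = "TACCCTGTAGAACCGAATTTGTGT"
closest_train_seq = "TACCCTGTAGAACCGAATTTGT"
threshold = 0.26  # assumed value; the real one is read from novelty_model_coef.yaml

nld = normalized_lev_dist(query, closest_train_seq)
print("Familiar" if nld <= threshold else "Novel", round(nld, 3))
```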
diff --git a/transforna/src/novelty_prediction/utlis.py b/transforna/src/novelty_prediction/utlis.py new file mode 100644 index 0000000000000000000000000000000000000000..d4c661431e65714c899317d46b0319ce22bb1309 --- /dev/null +++ b/transforna/src/novelty_prediction/utlis.py @@ -0,0 +1,60 @@ + +import matplotlib.pyplot as plt +from matplotlib import pyplot +from sklearn.metrics import (auc, f1_score, precision_recall_curve, + roc_auc_score, roc_curve) + +from ..utils.tcga_post_analysis_utils import Results_Handler + + +def compute_prc(test_labels,lr_probs,yhat,results:Results_Handler,show_figure:bool=False): + + lr_precision, lr_recall, _ = precision_recall_curve(test_labels, lr_probs) + lr_f1, lr_auc = f1_score(test_labels, yhat), auc(lr_recall, lr_precision) + # plot the precision-recall curves + if show_figure: + pyplot.plot(lr_recall, lr_precision, marker='.', label=results.figures_path.split('/')[-2]) + # axis labels + pyplot.xlabel('Recall') + pyplot.ylabel('Precision') + # show the legend + pyplot.legend() + # save and show the plot + plt.title("PRC Curve") + + if results.save_results: + plt.savefig(f"{results.figures_path}/prc_curve.png") + plt.savefig(f"{results.figures_path}/prc_curve.svg") + + if show_figure: + plt.show() + return lr_f1,lr_auc + +def compute_roc(test_labels,lr_probs,results,show_figure:bool=False): + + ns_probs = [0 for _ in range(len(test_labels))] + + # calculate scores + ns_auc = roc_auc_score(test_labels, ns_probs) + lr_auc = roc_auc_score(test_labels, lr_probs) + # calculate roc curves + ns_fpr, ns_tpr, _ = roc_curve(test_labels, ns_probs) + lr_fpr, lr_tpr, _ = roc_curve(test_labels, lr_probs) + + # plot the roc curve for the model + if show_figure: + plt.plot(lr_fpr, lr_tpr, marker='.',markersize=1, label=results.figures_path.split('/')[-2]) + # axis labels + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + # show the legend + plt.legend() + plt.title("ROC Curve") + + if results.save_results: + plt.savefig(f"{results.figures_path}/roc_curve.png") + plt.savefig(f"{results.figures_path}/roc_curve.svg") + + if show_figure: + plt.show() + return lr_auc diff --git a/transforna/src/processing/__init__.py b/transforna/src/processing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2375da73fbe052d75712a0ef64dfed6a8899a3e --- /dev/null +++ b/transforna/src/processing/__init__.py @@ -0,0 +1,3 @@ +from .augmentation import * +from .seq_tokenizer import * +from .splitter import * diff --git a/transforna/src/processing/augmentation.py b/transforna/src/processing/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..f1003aceea795d1898a24ad5370bcddec22812b3 --- /dev/null +++ b/transforna/src/processing/augmentation.py @@ -0,0 +1,497 @@ +import logging +import random +from contextlib import redirect_stdout +from pathlib import Path +from random import randint +from typing import Dict, List, Tuple + +import numpy as np +import pandas as pd + +from ..novelty_prediction.id_vs_ood_nld_clf import get_closest_ngbr_per_split +from ..utils.energy import fold_sequences +from ..utils.file import load +from ..utils.tcga_post_analysis_utils import Results_Handler +from ..utils.utils import (get_model, infer_from_pd, + prepare_inference_results_tcga, + update_config_with_inference_params) +from .seq_tokenizer import SeqTokenizer + +logger = logging.getLogger(__name__) + +class IDModelAugmenter: + ''' + This class is used to augment the dataset with the predictions of the ID models + It will first 
predict the subclasses of the NA set using the ID models + Then it will compute the levenstein distance between the sequences of the NA set and the closest neighbor in the training set + If the levenstein distance is less than a threshold, the sequence is considered familiar + ''' + def __init__(self,df:pd.DataFrame,config:Dict): + self.df = df + self.config = config + self.mapping_dict = load(config['train_config']['mapping_dict_path']) + + + def predict_transforna_na(self) -> Tuple: + infer_pd = pd.DataFrame(columns=['Sequence','Net-Label','Is Familiar?']) + + mc_or_sc = 'major_class' if 'major_class' in self.config['model_config']['clf_target'] else 'sub_class' + inference_config = update_config_with_inference_params(self.config,mc_or_sc=mc_or_sc,path_to_models=self.config['path_to_models']) + model_path = inference_config['inference_settings']["model_path"] + logger.info(f"Augmenting hico sequences based on predictions from model at: {model_path}") + + #path should be infer_cfg["model_path"] - 2 level + embedds + embedds_path = '/'.join(inference_config['inference_settings']["model_path"].split('/')[:-2])+'/embedds' + #read threshold + results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train','no_annotation']) + results.get_knn_model() + threshold = load(results.analysis_path+"/novelty_model_coef")["Threshold"] + sequences = results.splits_df_dict['no_annotation_df'][results.seq_col].values[:,0] + with redirect_stdout(None): + root_dir = Path(__file__).parents[3].absolute() + inference_config, net = get_model(inference_config, root_dir) + #original_infer_pd might include seqs that are longer than input model. if so, infer_pd contains the trimmed sequences + original_infer_pd = pd.Series(sequences, name="Sequences").to_frame() + logger.info(f'predicting sub classes for the NA set by the ID models') + predicted_labels, logits,_, _,all_data, max_len, net, infer_pd = infer_from_pd(inference_config, net, original_infer_pd, SeqTokenizer) + + + prepare_inference_results_tcga(inference_config,predicted_labels, logits, all_data, max_len) + infer_pd = all_data["infere_rna_seq"] + + #compute lev distance for embedds and + logger.info('computing levenstein distance for the NA set by the ID models') + _,_,_,_,_,lev_dist = get_closest_ngbr_per_split(results,'no_annotation') + + logger.info(f'num of hico based on entropy novelty prediction is {sum(infer_pd["Is Familiar?"])}') + infer_pd['Is Familiar?'] = [True if lv= 1].index.tolist() + #get one sequence per label in one line + samples = [self.df[self.df['Labels'] == label].sample(1).index[0] for label in unique_labels] + #makes number of samples even + if len(samples) % 2 != 0: + samples = samples[:-1] + np.random.shuffle(samples) + #split samples into two sets + samples_set1 = samples[:len(samples)//2] + samples_set2 = samples[len(samples)//2:] + #create fusion set + recombined_set = [] + for i in range(len(samples_set1)): + recombined_seq = samples_set1[i]+samples_set2[i] + #get index of the first ntd of the second sequence + recombined_index = len(samples_set1[i]) + #sample a random offset -5 and 5 + offset = randint(-5,5) + recombined_index += offset + #sample an int between 18 and 30 + random_half_len = int(randint(18,30)/2) #9 to 15 + #get the sequence from the recombined sequence + random_seq = recombined_seq[max(0,recombined_index - random_half_len):recombined_index + random_half_len] + recombined_set.append(random_seq) + + recombined_df = pd.DataFrame(index=recombined_set, 
data=[f'{class_label}']*len(recombined_set)\ + , columns =['Labels']) + + return recombined_df + + def get_augmented_df(self): + recombined_df = self.create_recombined_seqs() + return recombined_df + +class RandomSeqAugmenter: + ''' + This class is used to augment the dataset with random sequences within the same length range as the tcga sequences + ''' + def __init__(self,df:pd.DataFrame,config:Dict): + self.df = df + self.config = config + self.num_seqs = 500 + self.min_len = 18 + self.max_len = 30 + + def get_random_seq(self): + #create random sequences from bases: A,C,G,T with length 18-30 + random_seqs = [] + while len(random_seqs) < self.num_seqs: + random_seq = ''.join(random.choices(['A','C','G','T'], k=randint(self.min_len,self.max_len))) + if random_seq not in random_seqs and random_seq not in self.df.index: + random_seqs.append(random_seq) + + return pd.DataFrame(index=random_seqs, data=['random']*len(random_seqs)\ + , columns =['Labels']) + def get_augmented_df(self): + random_df = self.get_random_seq() + return random_df + +class PrecursorAugmenter: + def __init__(self,df:pd.DataFrame, config:Dict): + self.df = df + self.config = config + self.mapping_dict = load(config['train_config'].mapping_dict_path) + self.precursor_df = self.load_precursor_file() + self.trained_on = config.trained_on + + self.min_num_samples_per_sc:int=1 + if self.trained_on == 'id': + self.min_num_samples_per_sc = 8 + + self.min_bin_size = 20 + self.max_bin_size = 30 + self.min_seq_len = 18 + self.max_seq_len = 30 + + def load_precursor_file(self): + try: + precursor_df = pd.read_csv(self.config['train_config'].precursor_file_path, index_col=0) + return precursor_df + except: + logger.info('Could not load precursor file') + return None + + def compute_dynamic_bin_size(self,precursor_len:int, name:str=None) -> List[int]: + ''' + This function splits precursor to bins of size max_bin_size + if the last bin is smaller than min_bin_size, it will split the precursor to bins of size max_bin_size-1 + This process will continue until the last bin is larger than min_bin_size. + if the min bin size is reached and still the last bin is smaller than min_bin_size, the last two bins will be merged. 
+ so the maximimum bin size possible would be min_bin_size+(min_bin_size-1) = 39 + ''' + def split_precursor_to_bins(precursor_len,max_bin_size): + ''' + This function splits precursor to bins of size max_bin_size + ''' + precursor_bin_lens = [] + for i in range(0, precursor_len, max_bin_size): + if i+max_bin_size < precursor_len: + precursor_bin_lens.append(max_bin_size) + else: + precursor_bin_lens.append(precursor_len-i) + return precursor_bin_lens + + if precursor_len < self.min_bin_size: + return [precursor_len] + else: + precursor_bin_lens = split_precursor_to_bins(precursor_len,self.max_bin_size) + reduced_len = self.max_bin_size-1 + while precursor_bin_lens[-1] < self.min_bin_size: + precursor_bin_lens = split_precursor_to_bins(precursor_len,reduced_len) + reduced_len -= 1 + if reduced_len < self.min_bin_size: + #add last two bins together + precursor_bin_lens[-2] += precursor_bin_lens[-1] + precursor_bin_lens = precursor_bin_lens[:-1] + break + + return precursor_bin_lens + + def get_bin_with_max_overlap(self,precursor_len:int,start_frag_pos:int,frag_len:int,name) -> int: + ''' + This function returns the bin number of a fragment that overlaps the most with the fragment + ''' + precursor_bin_lens = self.compute_dynamic_bin_size(precursor_len=precursor_len,name=name) + bin_no = 0 + for i,bin_len in enumerate(precursor_bin_lens): + if start_frag_pos < bin_len: + #get overlap with curr bin + overlap = min(bin_len-start_frag_pos,frag_len) + + if overlap > frag_len/2: + bin_no = i + else: + bin_no = i+1 + break + + else: + start_frag_pos -= bin_len + return bin_no+1 + + def get_precursor_info(self,mc:str,sc:str): + + xRNA_df = self.precursor_df.loc[self.precursor_df.small_RNA_class_annotation == mc] + xRNA_df.index = xRNA_df.index.str.replace('|','-', regex=False) + prec_name = sc.split('_bin-')[0] + + if mc in ['snoRNA','lncRNA','protein_coding','miscRNA']: + prec_name = mc+'-'+prec_name + prec_row_df = xRNA_df.iloc[xRNA_df.index.str.contains(prec_name)] + #check if prec_row_df is empty + if prec_row_df.empty: + xRNA_df = self.precursor_df.loc[self.precursor_df.small_RNA_class_annotation == 'pseudo_'+mc] + xRNA_df.index = xRNA_df.index.str.replace('|','-', regex=False) + prec_row_df = xRNA_df.iloc[xRNA_df.index.str.contains(prec_name)] + if prec_row_df.empty: + logger.info(f'precursor {prec_name} not found in HBDxBase') + return pd.DataFrame() + + prec_row_df = prec_row_df.iloc[0] + else: + prec_row_df = xRNA_df.loc[f'{mc}-{prec_name}'] + + precursor = prec_row_df.sequence + return precursor,prec_name + + def populate_from_bin(self,sc:str,precursor:str,prec_name:str,existing_seqs:List[str]): + ''' + This function will first get the bin no from the sc. + Then it will do three types of sampling: + 1. sample from the previous bin, insuring that the overlap with the middle bin is the highest + 2. sample from the next bin, insuring that the overlap with the middle bin is the highest + 3. 
sample from the middle bin, insuring that the overlap with the middle bin is the highest + The staet idx should be the middle position of the previous bin, then start position is incremented until the end of the current bin + ''' + bin_no = int(sc.split('_bin-')[1]) + bins = self.compute_dynamic_bin_size(len(precursor), prec_name) + if len(bins) == 1: + return pd.DataFrame() + + #bins start from 1 so should subtract 1 + bin_no -= 1 + + #in case bin_no is 0 + try: + previous_bin_start = sum(bins[:bin_no-1]) + except: + previous_bin_start = 0 + middle_bin_start = sum(bins[:bin_no]) + next_bin_start = sum(bins[:bin_no+1]) + + + try: + previous_bin_size = bins[bin_no-1] + except: + previous_bin_size = 0 + + middle_bin_size = bins[bin_no] + try: + next_bin_size = bins[bin_no+1] + except: + next_bin_size = 0 + + + start_idx = previous_bin_start + previous_bin_size//2 + 1 #+1 to make sure max overlap with prev bin is 14. max len/2 - 1 + sampled_seqs = [] + #increase start idx until the end of the current bin + while start_idx < middle_bin_start+middle_bin_size: + #compute the boundaries of the length of the fragment so that it would always overlap with the middle bin the most + if start_idx < middle_bin_start: + max_overlap_prev = middle_bin_start - start_idx + end_idx = start_idx + randint(max(self.min_seq_len,max_overlap_prev*2+1),self.max_seq_len) + else:# start_idx >= middle_bin_start: + max_overlap_curr = next_bin_start - start_idx + max_overlap_next = (start_idx + self.max_seq_len) - next_bin_start + max_overlap_next = min(max_overlap_next,next_bin_size) + if max_overlap_curr <= 9 or (max_overlap_next==0 and max_overlap_curr < self.min_seq_len): + end_idx = -1 + else: + end_idx = start_idx + randint(self.min_seq_len,min(self.max_seq_len,self.max_seq_len - max_overlap_next + max_overlap_curr - 1)) + #max overlap with the middle bin will never exceed half of min fragment (9) or, + # next bin size is 0 so frag will be shorter than 18 + if end_idx == -1: + break + + tmp_seq = precursor[start_idx:end_idx] + #introduce mismatches + assert len(tmp_seq) >= self.min_seq_len and len(tmp_seq) <= self.max_seq_len, f'length of tmp_seq is {len(tmp_seq)}' + if tmp_seq not in existing_seqs: + sampled_seqs.append(tmp_seq) + start_idx += 1 + + #assertions + for frag in sampled_seqs: + all_occ = precursor.find(frag) + if not isinstance(all_occ,list): + all_occ = [all_occ] + + for occ in all_occ: + curr_bin_no = self.get_bin_with_max_overlap(len(precursor),occ,len(frag),' ') + # if curr_bin_no is different from bin_no+1 with more than 2 skip assertion + if abs(curr_bin_no - (bin_no+1)) > 1: + continue + assert curr_bin_no == bin_no+1, f'curr_bin_no is {curr_bin_no} and bin_no is {bin_no+1}' + + return pd.DataFrame(index=sampled_seqs, data=[sc]*len(sampled_seqs)\ + , columns =['Labels']) + + def populate_scs_with_bins(self): + augmented_df = pd.DataFrame(columns=['Labels']) + + #append samples per sc for bin continuity + unique_labels = self.df.Labels.value_counts()[self.df.Labels.value_counts() >= self.min_num_samples_per_sc].index.tolist() + scs_list = [] + scs_before = [] + sc_after = [] + for sc in unique_labels: + #retrieve_bin_from_precursor(other_sc_df,mapping_dict,sc) + if type(sc) == str and '_bin-' in sc: + #get mc + try: + mc = self.mapping_dict[sc] + except: + sc_mc_mapper = lambda x: 'miRNA' if 'miR' in x else 'tRNA' if 'tRNA' in x else 'rRNA' if 'rRNA' in x else 'snRNA' if 'snRNA' in x else 'snoRNA' if 'snoRNA' in x else 'snoRNA' if 'SNO' in x else 'protein_coding' if 'RPL37A' in x else 'lncRNA' 
if 'SNHG1' in x else None + mc = sc_mc_mapper(sc) + if mc is None: + logger.info(f'No mapping for {sc}') + continue + existing_seqs = self.df[self.df['Labels'] == sc].index + scs_list.append(sc) + scs_before.append(len(existing_seqs)) + #augment fragments from prev or consecutive bin + precursor,prec_name = self.get_precursor_info(mc,sc) + sc2_df = self.populate_from_bin(sc,precursor,prec_name,existing_seqs) + augmented_df = augmented_df.append(sc2_df) + sc_after.append(len(sc2_df)) + #make a dict of scs and number of samples before and after augmentation + scs_dict = {'sub_class':scs_list,'Number of samples before':scs_before,'Number of samples afrer':sc_after} + scs_df = pd.DataFrame(scs_dict) + scs_df.to_csv(f'frequency_per_sub_class_df.csv') + + + return augmented_df + + def get_augmented_df(self): + return self.populate_scs_with_bins() + +class DataAugmenter: + ''' + This class sets the labels of the dataset to major class or sub class labels based on the clf_target + major class: miRNA, tRNA ... + sub class: mir-192-3p, rRNA-bin-30 ... + Then if the models should be tained on ID models, it will augment the dataset with sequences sampled from the precursor file + If the models should be trained on full, it will augment the dataset based on the following: + 1. Random sequences + 2. Recombined sequences + 3. Sequences sampled from the precursor file + 4. predictions of the sequences that previously had no annotation of low confidence but were predicted to be familiar by the ID models + ''' + def __init__(self,df:pd.DataFrame, config:Dict): + self.df = df + self.config = config + self.mapping_dict = load(config['train_config'].mapping_dict_path) + self.trained_on = config.trained_on + self.clf_target = config['model_config'].clf_target + logger.info(f'Augmenting the dataset for {self.clf_target}') + self.set_labels() + + self.precursor_augmenter = PrecursorAugmenter(self.df,self.config) + self.random_augmenter = RandomSeqAugmenter(self.df,self.config) + self.recombined_augmenter = RecombinedSeqAugmenter(self.df,self.config) + self.id_model_augmenter = IDModelAugmenter(self.df,self.config) + + + + def set_labels(self): + if 'hico' not in self.clf_target: + self.df['Labels'] = self.df['subclass_name'].str.split(';', expand=True)[0] + else: + self.df['Labels'] = self.df['subclass_name'][self.df['hico'] == True] + + self.df['Labels'] = self.df['Labels'].astype('category') + + + def convert_to_major_class_labels(self): + if 'major_class' in self.clf_target: + self.df['Labels'] = self.df['Labels'].map(self.mapping_dict).astype('category') + #remove multitarget major classes + self.df = self.df[~self.df['Labels'].str.contains(';').fillna(False)] + + + def combine_df(self,new_var_df:pd.DataFrame): + #remove any sequences in augmented_df that exist in self.df.indexs + duplicated_df = new_var_df[new_var_df.index.isin(self.df.index)] + #log + if len(duplicated_df): + logger.info(f'Number of duplicated sequences to be removed augmented data: {duplicated_df.shape[0]}') + + new_var_df = new_var_df[~new_var_df.index.isin(self.df.index)].sample(frac=1) + + for col in self.df.columns: + if col not in new_var_df.columns: + new_var_df[col] = np.nan + + self.df = new_var_df.append(self.df) + self.df.index = self.df.index.str.upper() + self.df.Labels = self.df.Labels.astype('category') + return self.df + + + def annotate_artificial_affix_seqs(self): + #AA seqs are sequences that have 5' adapter + aa_seqs = self.df[self.df['five_prime_adapter_filter'] == 0].index.tolist() + self.df['Labels'] = 
self.df['Labels'].cat.add_categories('artificial_affix') + self.df.loc[aa_seqs,'Labels'] = 'artificial_affix' + + + + def full_pipeline(self): + self.df = self.id_model_augmenter.get_augmented_df() + + + def post_augmentation(self): + random_df = self.random_augmenter.get_augmented_df() + #augmentation is only done for sub_class + if 'sub_class' in self.clf_target: + df = self.precursor_augmenter.get_augmented_df() + else: + df = pd.DataFrame() + recombined_df = self.recombined_augmenter.get_augmented_df() + df = df.append(recombined_df).append(random_df) + self.df['Labels'] = self.df['Labels'].cat.add_categories({'random','recombined'}) + self.combine_df(df) + + self.convert_to_major_class_labels() + self.annotate_artificial_affix_seqs() + self.df['Labels'] = self.df['Labels'].cat.remove_unused_categories() + self.df['Sequences'] = self.df.index.tolist() + + if 'struct' in self.config['model_config'].model_input: + self.df['Secondary'] = fold_sequences(self.df.index.tolist(),temperature=37)[f'structure_37'].values + + return self.df + + def get_augmented_df(self): + if self.trained_on == 'full': + self.full_pipeline() + return self.post_augmentation() \ No newline at end of file diff --git a/transforna/src/processing/seq_tokenizer.py b/transforna/src/processing/seq_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6ff369dc6cc1fc4596870708da8ace27865c21 --- /dev/null +++ b/transforna/src/processing/seq_tokenizer.py @@ -0,0 +1,350 @@ + +import logging +import math +import os +import warnings +from random import randint + +import numpy as np +import pandas as pd +from numpy.lib.stride_tricks import as_strided +from omegaconf import DictConfig, open_dict + +from ..utils import energy +from ..utils.file import save + +logger = logging.getLogger(__name__) + +class SeqTokenizer: + ''' + This class should contain functions that other data specific classes should inherit from. 
+ ''' + def __init__(self,seqs_dot_bracket_labels: pd.DataFrame, config: DictConfig): + + self.seqs_dot_bracket_labels = seqs_dot_bracket_labels.reset_index(drop=True) + #shuffle + if not config["inference"]: + self.seqs_dot_bracket_labels = self.seqs_dot_bracket_labels\ + .sample(frac=1)\ + .reset_index(drop=True) + + #get input of model + self.model_input = config["model_config"].model_input + + + # set max length to be <= 2 stds of distribtion of lengths + if config["train_config"].filter_seq_length: + self.get_outlier_length_threshold() + self.limit_seqs_to_range() + + else: + self.max_length = self.seqs_dot_bracket_labels['Sequences'].str.len().max() + self.min_length = 0 + + with open_dict(config): + config["model_config"]["max_length"] = np.int64(self.max_length).item() + config["model_config"]["min_length"] = np.int64(self.min_length).item() + + self.window = config["model_config"].window + self.tokens_len = math.ceil(self.max_length / self.window) + if config["model_config"].tokenizer in ["overlap", "overlap_multi_window"]: + self.tokens_len = int(self.max_length - (config["model_config"].window - 1)) + self.tokenizer = config["model_config"].tokenizer + + + self.seq_len_dist = self.seqs_dot_bracket_labels['Sequences'].str.len().value_counts() + #init tokens dict + self.seq_tokens_ids_dict = {} + self.second_input_tokens_ids_dict = {} + + #get and set number of labels in config to be used later by the model + config["model_config"].num_classes = len(self.seqs_dot_bracket_labels['Labels'].unique()) + + self.set_class_attr() + + + def get_outlier_length_threshold(self): + lengths_arr = self.seqs_dot_bracket_labels['Sequences'].str.len() + mean = np.mean(lengths_arr) + standard_deviation = np.std(lengths_arr) + distance_from_mean = abs(lengths_arr - mean) + in_distribution = distance_from_mean < 2 * standard_deviation + + inlier_lengths = np.sort(lengths_arr[in_distribution].unique()) + self.max_length = int(np.max(inlier_lengths)) + self.min_length = int(np.min(inlier_lengths)) + logger.info(f'maximum and minimum sequence length is set to: {self.max_length} and {self.min_length}') + return + + + def limit_seqs_to_range(self): + ''' + Trimms seqs longer than maximum len and deletes seqs shorter than min length + ''' + df = self.seqs_dot_bracket_labels + min_to_be_deleted = [] + + num_longer_seqs = sum(df['Sequences'].str.len()>self.max_length) + if num_longer_seqs: + logger.info(f"Number of sequences to be trimmed: {num_longer_seqs}") + + + for idx,seq in enumerate(df['Sequences']): + if len(seq) > self.max_length: + df['Sequences'].iloc[idx] = \ + df['Sequences'].iloc[idx][:self.max_length] + + elif len(seq) < self.min_length: + #deleted sequence indices + min_to_be_deleted.append(str(idx)) + #delete min sequences + if len(min_to_be_deleted): + df = df.drop(min_to_be_deleted).reset_index(drop=True) + logger.info(f"Number of sequences shroter sequences to be removed: {len(min_to_be_deleted)}") + self.seqs_dot_bracket_labels = df + + def get_secondary_structure(self,sequences): + secondary = energy.fold_sequences(sequences.tolist()) + return secondary['structure_37'].values + + # function generating non overlapping tokens of a feature sample + def chunkstring_overlap(self, string, window): + return ( + string[0 + i : window + i] for i in range(0, len(string) - window + 1, 1) + ) + # function generating non overlapping tokens of a feature sample + def chunkstring_no_overlap(self, string, window): + return (string[0 + i : window + i] for i in range(0, len(string), window)) + + + def 
tokenize_samples(self, window:int,sequences_to_be_tokenized:pd.DataFrame,inference:bool=False,tokenizer:str="overlap") -> np.ndarray: + """ + This function tokenizes rnas based on window(window) + with or without overlap according to the current tokenizer option. + In case of overlap: + example: Token :AACTAGA, window: 3 + output: AAC,ACT,CTA,TAG,AGA + + In case no_overlap: + example: Token :AACTAGA, window: 3 + output: AAC,TAG,A + """ + # get feature tokens + if "overlap" in tokenizer: + feature_tokens_gen = list( + self.chunkstring_overlap(feature, window) + for feature in sequences_to_be_tokenized + ) + elif tokenizer == "no_overlap": + feature_tokens_gen = list( + self.chunkstring_no_overlap(feature, window) for feature in sequences_to_be_tokenized + ) + # get sample tokens and pad them + samples_tokenized = [] + sample_token_ids = [] + if not self.seq_tokens_ids_dict: + self.seq_tokens_ids_dict = {"pad": 0} + + for gen in feature_tokens_gen: + sample_token_id = [] + sample_token = list(gen) + sample_len = len(sample_token) + # append paddings + sample_token.extend( + ["pad" for _ in range(int(self.tokens_len - sample_len))] + ) + # convert tokens to ids + for token in sample_token: + # if token doesnt exist in dict, create one + if token not in self.seq_tokens_ids_dict: + if not inference: + id = len(self.seq_tokens_ids_dict.keys()) + self.seq_tokens_ids_dict[token] = id + else: + #if new token found during inference, then select random token (considered as noise) + logger.warning(f"The sequence token: {token} was not seen previously by the model. Token will be replaced by a random token") + id = randint(1,len(self.seq_tokens_ids_dict.keys()) - 1) + token = self.seq_tokens_ids_dict[id] + # append id of token + sample_token_id.append(self.seq_tokens_ids_dict[token]) + + # append ids of tokenized sample + sample_token_ids.append(np.array(sample_token_id)) + + sample_token = np.array(sample_token) + samples_tokenized.append(sample_token) + + return (np.array(samples_tokenized), np.array(sample_token_ids)) + + def tokenize_secondary_structure(self, window,sequences_to_be_tokenized,inference:bool=False,tokenizer= "overlap") -> np.ndarray: + """ + This function tokenizes rnas based on window(window) + with or without overlap according to the current tokenizer option. + In case of overlap: + example: Token :...()..., window: 3 + output: ...,..(,.(),().,)..,... + + In case no_overlap: + example: Token :...()..., window: 3 + output: ...,().,.. 
+ """ + samples_tokenized = [] + sample_token_ids = [] + if not self.second_input_tokens_ids_dict: + self.second_input_tokens_ids_dict = {"pad": 0} + + # get feature tokens + if "overlap" in tokenizer: + feature_tokens_gen = list( + self.chunkstring_overlap(feature, window) + for feature in sequences_to_be_tokenized + ) + elif "no_overlap" == tokenizer: + feature_tokens_gen = list( + self.chunkstring_no_overlap(feature, window) for feature in sequences_to_be_tokenized + ) + # get sample tokens and pad them + for seq_idx, gen in enumerate(feature_tokens_gen): + sample_token_id = [] + sample_token = list(gen) + + # convert tokens to ids + for token in sample_token: + # if token doesnt exist in dict, create one + if token not in self.second_input_tokens_ids_dict: + if not inference: + id = len(self.second_input_tokens_ids_dict.keys()) + self.second_input_tokens_ids_dict[token] = id + else: + #if new token found during inference, then select random token (considered as noise) + warnings.warn(f"The secondary structure token: {token} was not seen previously by the model. Token will be replaced by a random token") + id = randint(1,len(self.second_input_tokens_ids_dict.keys()) - 1) + token = self.second_input_tokens_ids_dict[id] + # append id of token + sample_token_id.append(self.second_input_tokens_ids_dict[token]) + # append ids of tokenized sample + sample_token_ids.append(sample_token_id) + samples_tokenized.append(sample_token) + + #append pads + #max length is number of different temp used* max token len PLUS the concat token + # between two secondary structures represented at two diff temperatures + self.second_input_token_len = self.tokens_len + for seq_idx, token in enumerate(sample_token_ids): + sample_len = len(token) + sample_token_ids[seq_idx].extend( + [self.second_input_tokens_ids_dict["pad"] for _ in range(int(self.second_input_token_len - sample_len))] + ) + samples_tokenized[seq_idx].extend( + ["pad" for _ in range(int(self.second_input_token_len - sample_len))] + ) + sample_token_ids[seq_idx] = np.array(sample_token_ids[seq_idx]) + samples_tokenized[seq_idx] = np.array(samples_tokenized[seq_idx]) + # save vocab + return (np.array(samples_tokenized), np.array(sample_token_ids)) + + def set_class_attr(self): + #set seq,struct and exp and labels + self.seq = self.seqs_dot_bracket_labels["Sequences"] + if 'struct' in self.model_input: + self.struct = self.seqs_dot_bracket_labels["Secondary"] + + self.labels = self.seqs_dot_bracket_labels['Labels'] + + def prepare_multi_idx_pd(self,num_coln,pd_name,pd_value): + iterables = [[pd_name], np.arange(num_coln)] + index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"]) + return pd.DataFrame(columns=index, data=pd_value) + + def phase_sequence(self,sample_token_ids): + phase0 = sample_token_ids[:,::2] + phase1 = sample_token_ids[:,1::2] + #in case max_length is an odd number phase 0 will be 1 entry larger than phase 1 @ dim=1 + if phase0.shape!= phase1.shape: + phase1 = np.concatenate([phase1,np.zeros(phase1.shape[0])[...,np.newaxis]],axis=1) + sample_token_ids = phase0 + + return sample_token_ids,phase1 + + def custom_roll(self,arr, n_shifts_per_row): + ''' + shifts each row of a numpy array according to n_shifts_per_row + ''' + m = np.asarray(n_shifts_per_row) + arr_roll = arr[:, [*range(arr.shape[1]),*range(arr.shape[1]-1)]].copy() #need `copy` + strd_0, strd_1 = arr_roll.strides + n = arr.shape[1] + result = as_strided(arr_roll, (*arr.shape, n), (strd_0 ,strd_1, strd_1)) + + return 
result[np.arange(arr.shape[0]), (n-m)%n] + + def save_token_dicts(self): + #save token dicts + save(data = self.second_input_tokens_ids_dict,path = os.getcwd()+'/second_input_tokens_ids_dict') + save(data = self.seq_tokens_ids_dict,path = os.getcwd()+'/seq_tokens_ids_dict') + + + def get_tokenized_data(self,inference:bool=False): + #tokenize sequences + samples_tokenized,sample_token_ids = self.tokenize_samples(self.window,self.seq,inference) + + logger.info(f'Vocab size for primary sequences: {len(self.seq_tokens_ids_dict.keys())}') + logger.info(f'Vocab size for secondary structure: {len(self.second_input_tokens_ids_dict.keys())}') + logger.info(f'Number of sequences used for tokenization: {samples_tokenized.shape[0]}') + + #tokenize struct if used + if "comp" in self.model_input: + #get compliment of self.seq + self.seq_comp = [] + for feature in self.seq: + feature = feature.replace('A','%temp%').replace('T','A')\ + .replace('C','%temp2%').replace('G','C')\ + .replace('%temp%','T').replace('%temp2%','G') + self.seq_comp.append(feature) + #store seq_tokens_ids_dict + self.seq_tokens_ids_dict_temp = self.seq_tokens_ids_dict + self.seq_tokens_ids_dict = None + #tokenize compliment + _,seq_comp_str_token_ids = self.tokenize_samples(self.window,self.seq_comp,inference) + sec_input_value = seq_comp_str_token_ids + #store second input seq_tokens_ids_dict + self.second_input_tokens_ids_dict = self.seq_tokens_ids_dict + self.seq_tokens_ids_dict = self.seq_tokens_ids_dict_temp + + + #tokenize struct if used + if "struct" in self.model_input: + _,sec_str_token_ids = self.tokenize_secondary_structure(self.window,self.struct,inference) + sec_input_value = sec_str_token_ids + + + #add seq-seq if used + if "seq-seq" in self.model_input: + sample_token_ids,sec_input_value = self.phase_sequence(sample_token_ids) + self.second_input_tokens_ids_dict = self.seq_tokens_ids_dict + + #in case of baseline or only "seq", the second input is dummy + #TODO:: refactor transforna to accept models with a single input (baseline and "seq") + # without occupying unnecessary resources + if "seq-rev" in self.model_input or "baseline" in self.model_input or self.model_input == 'seq': + sample_token_ids_rev = sample_token_ids[:,::-1] + n_zeros = np.count_nonzero(sample_token_ids_rev==0, axis=1) + sec_input_value = self.custom_roll(sample_token_ids_rev, -n_zeros) + self.second_input_tokens_ids_dict = self.seq_tokens_ids_dict + + + + + seqs_length = list(sum(sample_token_ids.T !=0)) + + labels_df = self.prepare_multi_idx_pd(1,"Labels",self.labels.values) + tokens_id_df = self.prepare_multi_idx_pd(sample_token_ids.shape[1],"tokens_id",sample_token_ids) + tokens_df = self.prepare_multi_idx_pd(samples_tokenized.shape[1],"tokens",samples_tokenized) + sec_input_df = self.prepare_multi_idx_pd(sec_input_value.shape[1],'second_input',sec_input_value) + seqs_length_df = self.prepare_multi_idx_pd(1,"seqs_length",seqs_length) + + all_df = labels_df.join(tokens_df).join(tokens_id_df).join(sec_input_df).join(seqs_length_df) + + #save token dicts + self.save_token_dicts() + return all_df \ No newline at end of file diff --git a/transforna/src/processing/splitter.py b/transforna/src/processing/splitter.py new file mode 100644 index 0000000000000000000000000000000000000000..564e415842c07534379d7e07a9364916dc9167bb --- /dev/null +++ b/transforna/src/processing/splitter.py @@ -0,0 +1,222 @@ +import logging +import os +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from sklearn.model_selection import 
train_test_split +from sklearn.preprocessing import LabelEncoder +from sklearn.utils.class_weight import (compute_class_weight, + compute_sample_weight) +from skorch.dataset import Dataset +from skorch.helper import predefined_split + +from ..utils.energy import fold_sequences +from ..utils.file import load, save +from ..utils.utils import (revert_seq_tokenization, + update_config_with_dataset_params_benchmark, + update_config_with_dataset_params_tcga) +from anndata import AnnData + +logger = logging.getLogger(__name__) +class DataSplitter: + def __init__(self,tokenizer,configs): + self.tokenizer = tokenizer + self.configs = configs + self.seed = configs.seed + self.trained_on = configs.trained_on + self.device = configs["train_config"].device + self.splits_df_dict = {} + self.min_num_samples_per_class = 10 + + def convert_to_tensor(self,in_arr,convert_type): + tensor_dtype = torch.long if convert_type == int else torch.float + return torch.tensor( + np.array(in_arr, dtype=convert_type), + dtype=tensor_dtype, + ).to(device=self.device) + + def get_features_per_split(self): + model_input_cols = ['tokens_id','second_input','seqs_length'] + features_dict = {} + for split_df in self.splits_df_dict.keys(): + split_data = self.convert_to_tensor(self.splits_df_dict[split_df][model_input_cols].values,convert_type=float) + split = '_'.join(split_df.split('_')[:-1]) + features_dict[f'{split}_data'] = split_data + + return features_dict + + def append_sample_weights(self,splits_features_dict): + + for split_df in self.splits_df_dict.keys(): + if split_df in ['train_df','valid_df','test_df']: + split_weights = self.convert_to_tensor(compute_sample_weight('balanced',self.splits_df_dict[split_df]['Labels'][0]),convert_type=float) + else: + split_weights = self.convert_to_tensor(np.ones(self.splits_df_dict[split_df].shape[0]),convert_type=float) + split = '_'.join(split_df.split('_')[:-1]) + splits_features_dict[f'{split}_data'] = torch.cat([splits_features_dict[f'{split}_data'],split_weights[:,None]],dim=1) + + return + + def get_labels_per_split(self): + #encode labels + enc = LabelEncoder() + enc.fit(self.splits_df_dict["train_df"]['Labels']) + #save mapping dict to config + self.configs["model_config"].class_mappings = enc.classes_.tolist() + + labels_dict = {} + labels_numeric_dict = {} + for split_df in self.splits_df_dict.keys(): + split = '_'.join(split_df.split('_')[:-1]) + + split_labels = self.splits_df_dict[split_df]['Labels'] + if split_df in ['train_df','valid_df','test_df']: + split_labels_numeric = self.convert_to_tensor(enc.transform(split_labels), convert_type=int) + else: + split_labels_numeric = self.convert_to_tensor(np.zeros((split_labels.shape[0])), convert_type=int) + + labels_dict[f'{split}_labels'] = split_labels + labels_numeric_dict[f'{split}_labels_numeric'] = split_labels_numeric + + #compute class weight + class_weights = compute_class_weight(class_weight='balanced',classes=np.unique(labels_dict['train_labels']),y=labels_dict['train_labels'][0].values) + + #omegaconfig does not support float64 as datatype so conversion to str is done + # and reconversion is done in criterion + self.configs['model_config'].class_weights = [str(x) for x in list(class_weights)] + + + return labels_dict | labels_numeric_dict + + + def get_seqs_per_split(self): + rna_seq_dict = {} + for split_df in self.splits_df_dict.keys(): + split = '_'.join(split_df.split('_')[:-1]) + rna_seq_dict[f'{split}_rna_seq'] = revert_seq_tokenization(self.splits_df_dict[split_df]["tokens"],self.configs) + + return 
rna_seq_dict + + def duplicate_fewer_classes(self,df): + #get quantity of each class and append it as a column + df["Quantity",'0'] = df["Labels"].groupby([0])[0].transform("count") + frequent_samples_df = df[df["Quantity",'0'] >= self.min_num_samples_per_class].reset_index(drop=True) + fewer_samples_df = df[df["Quantity",'0'] < self.min_num_samples_per_class].reset_index(drop=True) + unique_fewer_samples_df = fewer_samples_df.drop_duplicates(subset=[('Labels',0)], keep="last") + unique_fewer_samples_df['Quantity','0'] -= self.min_num_samples_per_class + unique_fewer_samples_df['Quantity','0'] = unique_fewer_samples_df['Quantity','0'].abs() + repeated_fewer_samples_df = unique_fewer_samples_df.loc[unique_fewer_samples_df.index.repeat(unique_fewer_samples_df.Quantity['0'])] + repeated_fewer_samples_df = repeated_fewer_samples_df.reset_index(drop=True) + df = frequent_samples_df.append(repeated_fewer_samples_df).append(fewer_samples_df).reset_index(drop=True) + df.drop(columns = ['Quantity'],inplace=True) + return df + + def remove_fewer_samples(self,data_df): + if 'sub_class' in self.configs['model_config']['clf_target']: + counts = data_df['Labels'].value_counts() + fewer_class_ids = counts[counts < self.min_num_samples_per_class].index + fewer_class_labels = [i[0] for i in fewer_class_ids] + elif 'major_class' in self.configs['model_config']['clf_target']: + #ensure that major classes are the same as the ones used when training for sub_class + #this is done for performance comparisons to be valid + #otherwise major class models would be trained on more major classes than sub_class models + tcga_df = load(self.configs['train_config'].dataset_path_train) + #only keep hico + tcga_df = tcga_df[tcga_df['hico'] == True] + if isinstance(tcga_df,AnnData): + tcga_df = tcga_df.var + #get subclass_name with samples higher than self.min_num_samples_per_class + counts = tcga_df['subclass_name'].value_counts() + all_subclasses = tcga_df['subclass_name'].unique() + selected_subclasses = counts[counts >= self.min_num_samples_per_class].index + #convert subclass_name to major_class + subclass_to_major_class_dict = load(self.configs['train_config'].mapping_dict_path) + all_major_classes = list(set([subclass_to_major_class_dict[sub_class] for sub_class in all_subclasses])) + selected_major_classes = list(set([subclass_to_major_class_dict[sub_class] for sub_class in selected_subclasses])) + fewer_class_labels = list(set(all_major_classes) - set(selected_major_classes)) + + #remove samples whose major_class is not among the selected major classes + fewer_samples_per_class_df = data_df.loc[data_df['Labels'].isin(fewer_class_labels).values, :] + fewer_ids = data_df.index.isin(fewer_samples_per_class_df.index) + data_df = data_df[~fewer_ids] + return fewer_samples_per_class_df,data_df + + def split_tcga(self,data_df): + #remove artificial_affix + artificial_df = data_df.loc[data_df['Labels'][0].isin(['random','recombined','artificial_affix'])] + art_ids = data_df.index.isin(artificial_df.index) + data_df = data_df[~art_ids] + data_df = data_df.reset_index(drop=True) + + #remove no annotations + no_annotaton_df = data_df.loc[data_df['Labels'].isnull().values] + n_a_ids = data_df.index.isin(no_annotaton_df.index) + data_df = data_df[~n_a_ids].reset_index(drop=True) + no_annotaton_df = no_annotaton_df.reset_index(drop=True) + + if self.trained_on == 'full': + #duplication is done as otherwise train_test_split would fail for classes with too few samples + data_df = self.duplicate_fewer_classes(data_df) + ood_dict = {} + else: + ood_df,data_df = 
self.remove_fewer_samples(data_df)
+            ood_dict = {"ood_df":ood_df}
+        #split data
+        train_df,valid_test_df = train_test_split(data_df,stratify=data_df["Labels"],train_size=0.8,random_state=self.seed)
+        if self.trained_on == 'id':
+            valid_df,test_df = train_test_split(valid_test_df,stratify=valid_test_df["Labels"],train_size=0.5,random_state=self.seed)
+        else:
+            #we need to use all sequences in the training set; however, unseen samples should still be set aside for training novelty prediction,
+            #otherwise the NLD on the test set would be zero
+            #move one sample from each class to test_df
+            test_df = valid_test_df.drop_duplicates(subset=[('Labels',0)], keep="last")
+            test_ids = valid_test_df.index.isin(test_df.index)
+            valid_df = valid_test_df[~test_ids].reset_index(drop=True)
+            train_df = train_df.append(valid_df).reset_index(drop=True)
+
+        self.splits_df_dict = {"train_df":train_df,"valid_df":valid_df,"test_df":test_df,"artificial_df":artificial_df,"no_annotation_df":no_annotaton_df} | ood_dict
+
+    def prepare_data_tcga(self):
+        """
+        This function receives the tokenizer and prepares the data in a format suitable for training.
+        It also sets default parameters in the config that cannot be known until the preprocessing step
+        is done.
+        """
+        all_data_df = self.tokenizer.get_tokenized_data()
+
+        #split data
+        self.split_tcga(all_data_df)
+
+        num_samples = self.splits_df_dict['train_df'].shape[0]
+        num_classes = len(self.splits_df_dict['train_df'].Labels.value_counts()[self.splits_df_dict['train_df'].Labels.value_counts()>0])
+        #log
+        logger.info(f'Training with {num_classes} classes and {num_samples} samples')
+
+        #get features, labels, and seqs per split
+        splits_features_dict = self.get_features_per_split()
+        self.append_sample_weights(splits_features_dict)
+        splits_labels_dict = self.get_labels_per_split()
+        splits_seqs_dict = self.get_seqs_per_split()
+
+
+        #prepare validation set for skorch
+        valid_ds = Dataset(splits_features_dict["valid_data"],splits_labels_dict["valid_labels_numeric"])
+        valid_ds = predefined_split(valid_ds)
+
+        #combine all dicts
+        all_data = splits_features_dict | splits_labels_dict | splits_seqs_dict | \
+            {"valid_ds":valid_ds}
+
+        ###update self.configs
+        update_config_with_dataset_params_tcga(self.tokenizer,all_data_df,self.configs)
+        self.configs["model_config"].num_classes = len(all_data['train_labels'][0].unique())
+        self.configs["train_config"].batch_per_epoch = int(all_data["train_data"].shape[0]\
+            /self.configs["train_config"].batch_size)
+
+        return all_data
+
+
+
diff --git a/transforna/src/readme.md b/transforna/src/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..6cfd9fea5038d98c35a4ddac7dacc8faae01dbc6
--- /dev/null
+++ b/transforna/src/readme.md
@@ -0,0 +1,22 @@
+This is the transforna package which contains the following modules:
+
+- `train` is the entry point where data preparation, training and results logging are executed.
+
+- `processing` contains all classes used for data augmentation, tokenization and splitting.
+
+- `model` contains the skorch model `skorchWrapper` that wraps the torch model described in model components.
+
+- `callbacks` contains the learning rate scheduler, loss function and the metrics used to evaluate the model.
+
+- `score` computes the balanced accuracy of the classification task -major or sub-class- for each of the splits with known labels (train/valid/test).
+
+- `novelty_prediction` contains two novelty metrics: entropy-based (obsolete) and Normalized Levenshtein Distance (NLD)-based (current).
+
+- `inference` contains all inference functionalities. Check `transforna/scripts/test_inference_api.py` for how to use it.
+
+A schematic of the TransfoRNA Architecture:
+
+
+![TransfoRNA Architecture](https://github.com/gitHBDX/TransfoRNA/assets/82571392/a1bfbb1e-32c9-4faf-96ae-46727c27e321)
+
+Model evaluation image [source](https://medium.com/@sachinsoni600517/model-evaluation-techniques-in-machine-learning-47ae9fb0ad33)
diff --git a/transforna/src/score/__init__.py b/transforna/src/score/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8dbd7a82192da5cdf247971562c1db02fc4904b
--- /dev/null
+++ b/transforna/src/score/__init__.py
@@ -0,0 +1 @@
+from .score import *
diff --git a/transforna/src/score/score.py b/transforna/src/score/score.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfdffa7ed0d57504a426344dc03e6552ee9899c0
--- /dev/null
+++ b/transforna/src/score/score.py
@@ -0,0 +1,321 @@
+import logging
+import os
+import pickle
+from typing import Dict
+
+import numpy as np
+import pandas as pd
+import skorch
+import torch
+import torch.nn as nn
+from sklearn.metrics import confusion_matrix
+
+from ..utils.file import save
+
+logger = logging.getLogger(__name__)
+
+def load_pkl(name):
+    with open(name + '.pkl', 'rb') as f:
+        return pickle.load(f)
+
+def infere_additional_test_data(net,data):
+    '''
+    The premirna task has an additional dataset containing premirna from different species.
+    This function computes the accuracy score on this additional test set.
+    All samples in the additional test data are precursor miRNAs.
+    '''
+    for dataset_idx in range(len(data)):
+        predictions = net.predict(data[dataset_idx])
+        correct = sum(torch.max(predictions,1).indices)
+        total = len(torch.max(predictions,1).indices)
+        logger.info(f'The prediction on the {dataset_idx} dataset is {correct} out of {total}')
+
+def get_rna_seqs(seq, model_config):
+    rna_seqs = []
+    if model_config.tokenizer == "no_overlap":
+        for _, row in seq.iterrows():
+            rna_seqs.append("".join(x for x in row if x != "pad"))
+    else:
+        rna_seqs_overlap = []
+        for _, row in seq.iterrows():
+            # remove the paddings
+            rna_seqs_overlap.append([x for x in row if x != "pad"])
+            # join the first character of each token in rna_seqs_overlap
+            rna_seqs.append("".join(x[0] for x in rna_seqs_overlap[-1]))
+            # append the last token w/o its first char
+            rna_seqs[-1] = "".join(rna_seqs[-1] + rna_seqs_overlap[-1][-1][1:])
+
+    return rna_seqs
+
+def save_embedds(net,path:str,rna_seq,split:str='train',labels:pd.DataFrame=None,model_config = None,logits=None):
+    #reconstruct seqs
+    # join sequence and remove pads
+    rna_seqs = get_rna_seqs(rna_seq, model_config)
+
+    # create pandas dataframe of sequences
+    iterables = [["RNA Sequences"], np.arange(1, dtype=int)]
+    index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"])
+    rna_seqs_df = pd.DataFrame(columns=index, data=np.vstack(rna_seqs))
+
+    data=np.vstack(net.gene_embedds)
+    # create pandas dataframe for the gene (sequence) embeddings
+    iterables = [["RNA Embedds"], np.arange((data.shape[1]), dtype=int)]
+    index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"])
+    gene_embedd_df = pd.DataFrame(columns=index, data=data)
+
+    if 'baseline' not in model_config.model_input:
+        data = np.vstack(net.second_input_embedds)
+        iterables = [["SI Embedds"], np.arange(data.shape[1], dtype=int)]
+        index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"])
+        exp_embedd_df = pd.DataFrame(columns=index, data=data)
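+    # for 'baseline' model inputs the network collects no second-input embeddings,
+    # so an empty placeholder is used below and the DataFrame joins simply skip it
+    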
else: + exp_embedd_df = [] + + iterables = [["Labels"], np.arange(1, dtype=int)] + index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"]) + labels_df = pd.DataFrame(columns=index, data=labels.values) + + if logits: + iterables = [["Logits"], model_config.class_mappings] + index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"]) + logits_df = pd.DataFrame(columns=index, data=np.array(logits)) + + final_csv = rna_seqs_df.join(gene_embedd_df).join(exp_embedd_df).join(labels_df).join(logits_df) + else: + final_csv = rna_seqs_df.join(gene_embedd_df).join(exp_embedd_df).join(labels_df) + + save(data=final_csv,path =f'{path}{split}_embedds') + + +def infer_from_model(net,split_data:torch.Tensor): + batch_size = 100 + predicted_labels_str = [] + soft = nn.Softmax() + logits = [] + attn_scores_first_list = [] + attn_scores_second_list = [] + #this dict will be used to convert between neumeric predictions and string labels + labels_mapping_dict = net.labels_mapping_dict + #switch labels and str_labels + labels_mapping_dict = {y:x for x,y in labels_mapping_dict.items()} + for idx,batch in enumerate(torch.split(split_data, batch_size)): + predictions = net.predict(batch) + attn_scores_first,attn_scores_second = net.get_attention_scores(batch) + predictions = predictions[:,:-1] + + max_ids_tensor = torch.max(predictions,1).indices + if max_ids_tensor.is_cuda: + max_ids_tensor = max_ids_tensor.cpu().numpy() + predicted_labels_str.extend([labels_mapping_dict[x] for x in max_ids_tensor.tolist()]) + + logits.extend(soft(predictions).detach().cpu().numpy()) + + attn_scores_first_list.extend(attn_scores_first) + if attn_scores_second is not None: + attn_scores_second_list.extend(attn_scores_second) + + return predicted_labels_str,logits,attn_scores_first_list,attn_scores_second_list + +def get_split_score(net,split_data:torch.Tensor,split_labels:torch.Tensor,split:str,scoring_function:Dict,task:str=None,log_split_str_labels:bool=False,mirna_flag:bool = None): + split_acc = [] + batch_size = 100 + predicted_labels_str = [] + true_labels_str = [] + #this dict will be used to convert between neumeric predictions and string labels + labels_mapping_dict = net.labels_mapping_dict + #switch labels and str_labels + labels_mapping_dict = {y:x for x,y in labels_mapping_dict.items()} + for idx,batch in enumerate(torch.split(split_data, batch_size)): + predictions = net.predict(batch) + if split_labels is not None: + true_labels = torch.split(split_labels,batch_size)[idx] + if mirna_flag is not None: + batch_score = scoring_function(true_labels.numpy(), predictions,task=task,mirna_flag=mirna_flag) + batch_score /= sum(true_labels.numpy().squeeze() == mirna_flag) + else: + batch_score = scoring_function(true_labels.numpy(), predictions,task=task) + split_acc.append(batch_score) + + if log_split_str_labels: + #save true labels + if split_labels is not None: + true_labels_str.extend([labels_mapping_dict[x[0]] for x in true_labels.numpy().tolist()]) + predicted_labels_str.extend([labels_mapping_dict[x] for x in torch.max(predictions[:,:-1],1).indices.cpu().numpy().tolist()]) + + if log_split_str_labels: + #save all true and predicted labels to compute metrics on + if split_labels is not None: + with open(f"true_labels_{split}.pkl", "wb") as fp: + pickle.dump(true_labels_str, fp) + + with open(f"predicted_labels_{split}.pkl", "wb") as fp: + pickle.dump(predicted_labels_str, fp) + + + if split_labels is not None: + split_score = sum(split_acc)/len(split_acc) + if mirna_flag is 
not None: + logger.info(f"{split} accuracy score is {split_score} for mirna: {mirna_flag}") + else: + #only for inference + split_score = None + + logger.info(f"{split} accuracy score is {split_score}") + + return split_score,predicted_labels_str + +def generate_embedding(net, path:str,accuracy_sore,data, data_seq,labels,labels_numeric,split,model_config=None,train_config=None,log_embedds:bool=False): + + predictions_per_split = [] + accuracy = [] + logits = [] + weights_per_batch = [] + data = torch.cat((data.T,labels_numeric.unsqueeze(1).T)).T + for batch in torch.split(data, train_config.batch_size): + weights_per_batch.append(batch.shape[0]) + predictions = net.predict(batch[:,:-1]) + soft = nn.Softmax(dim=1) + logits.extend(list(soft(predictions[:,:-1]).detach().cpu().tolist())) + + accuracy.append(accuracy_sore(batch[:,-1], predictions)) + + #drop sample weights + predictions = predictions[:,:-1] + + predictions = torch.argmax(predictions,axis=1) + predictions_per_split.extend(predictions.tolist()) + + if split == 'test': + matrix = confusion_matrix(labels_numeric.tolist(), predictions_per_split) + #get the worst predicted classes + worst_predicted_classes = np.argsort(matrix.diagonal())[:40] + best_predicted_classes = np.argsort(matrix.diagonal())[-40:] + #first get the mapping dict from labels_numeric tensor and labels containing string labels + mapping_dict = {} + for idx,label in enumerate(labels_numeric.tolist()): + mapping_dict[label] = labels.values[idx][0] + #convert worst_predicted_classes to string labels + worst_predicted_classes = [mapping_dict[x] for x in worst_predicted_classes] + #save worst predicted classes as csv + pd.DataFrame(worst_predicted_classes).to_csv(f"{path}worst_predicted_classes.csv") + + #check how many files in path start with confusion_matrix + num_confusion_matrix = len([name for name in os.listdir(path) if name.startswith("confusion_matrix")]) + #save confusion matrix + cf = pd.DataFrame(matrix) + #rename cf columns to be the labels by first ordering the mapping dict by the keys + cf.columns = [mapping_dict[x] for x in sorted(mapping_dict.keys())] + cf.index = cf.columns + cf.to_csv(f"{path}confusion_matrix_{num_confusion_matrix}.csv") + + + score_avg = 0 + if split in ['train','valid','test']: + score_avg = np.average(accuracy,weights = weights_per_batch) + logger.info(f"total accuracy score on {split} is {np.round(score_avg,4)}") + + + if log_embedds: + logger.debug(f"logging embedds for {split} set") + save_embedds(net,path,data_seq,split,labels,model_config,logits) + + return score_avg + + + +def compute_score_tcga( + net, all_data, path,cfg:Dict +): + task = cfg['task'] + net.load_params(f_params=f'{path}/ckpt/model_params_{task}.pt') + net.save_embedding = True + + #create path for embedds and confusion matrix + embedds_path = path+"/embedds/" + if not os.path.exists(embedds_path): + os.mkdir(embedds_path) + + #get scoring function + for cb in net.callbacks: + if type(cb) == skorch.callbacks.scoring.BatchScoring: + scoring_function = cb.scoring._score_func + break + + splits = ['train','valid','test','ood','no_annotation','artificial'] + + test_score = 0 + #log all splits + for split in splits: + # reset tensors + net.gene_embedds = [] + net.second_input_embedds = [] + try: + score = generate_embedding(net,embedds_path,scoring_function,all_data[f"{split}_data"],all_data[f"{split}_rna_seq"],\ + all_data[f"{split}_labels"],all_data[f"{split}_labels_numeric"],f'{split}',\ + cfg['model_config'],cfg['train_config'],cfg['log_embedds']) + if split == 
'test': + test_score = score + except: + trained_on = cfg['trained_on'] + logger.info(f'Skipping {split} split, as split does not exist for models trained on {trained_on}!') + + + + return test_score + + + + + +def compute_score_benchmark( + net, path,all_data,scoring_function:Dict, cfg:Dict +): + task = cfg['task'] + net.load_params(f_params=f'{path}/ckpt/model_params_{task}.pt') + net.save_embedding = True + # reset tensors + net.gene_embedds = [] + net.second_input_embedds = [] + + if task == 'premirna': + get_split_score(net,all_data["train_data"],all_data["train_labels_numeric"],'train',scoring_function,task,mirna_flag = 0) + get_split_score(net,all_data["train_data"],all_data["train_labels_numeric"],'train',scoring_function,task,mirna_flag = 1) + else: + get_split_score(net,all_data["train_data"],all_data["train_labels_numeric"],'train',scoring_function,task) + + embedds_path = path+"/embedds/" + if not os.path.exists(embedds_path): + os.mkdir(embedds_path) + if cfg['log_embedds']: + torch.save(torch.vstack(net.gene_embedds), embedds_path+"train_gene_embedds.pt") + torch.save(torch.vstack(net.second_input_embedds), embedds_path+"train_gene_exp_embedds.pt") + all_data["train_rna_seq"].to_pickle(embedds_path+"train_rna_seq.pkl") + + # reset tensors + net.gene_embedds = [] + net.second_input_embedds = [] + if task == 'premirna': + test_score_0,_ = get_split_score(net,all_data["test_data"],all_data["test_labels_numeric"],'test',scoring_function,task,mirna_flag = 0) + test_score_1,_ = get_split_score(net,all_data["test_data"],all_data["test_labels_numeric"],'test',scoring_function,task,mirna_flag = 1) + test_score = (test_score_0+test_score_1)/2 + else: + test_score,_ = get_split_score(net,all_data["test_data"],all_data["test_labels_numeric"],'test',scoring_function,task) + + if cfg['log_embedds']: + torch.save(torch.vstack(net.gene_embedds), embedds_path+"test_gene_embedds.pt") + torch.save(torch.vstack(net.second_input_embedds), embedds_path+"test_gene_exp_embedds.pt") + all_data["test_rna_seq"].to_pickle(embedds_path+"test_rna_seq.pkl") + return test_score + + + +def infer_testset(net,cfg,all_data,accuracy_score): + if cfg["task"] == 'premirna': + split_score,predicted_labels_str = get_split_score(net,all_data["test_data"],all_data["test_labels"],'test',accuracy_score,cfg["task"],mirna_flag = 0) + split_score,predicted_labels_str = get_split_score(net,all_data["test_data"],all_data["test_labels"],'test',accuracy_score,cfg["task"],mirna_flag = 1) + else: + split_score,predicted_labels_str = get_split_score(net,all_data["test_data"],all_data["test_labels_numeric"],'test',\ + accuracy_score,cfg["task"],log_split_str_labels = True) + #only for premirna + if "additional_testset" in all_data: + infere_additional_test_data(net,all_data["additional_testset"]) \ No newline at end of file diff --git a/transforna/src/train/__init__.py b/transforna/src/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b799fb9405fb783f4e24482faebb109c7091400 --- /dev/null +++ b/transforna/src/train/__init__.py @@ -0,0 +1 @@ +from .train import * diff --git a/transforna/src/train/train.py b/transforna/src/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b7fea08846154a3a26caa993aea77943839aa7 --- /dev/null +++ b/transforna/src/train/train.py @@ -0,0 +1,90 @@ +import logging +from typing import Dict + +from anndata import AnnData +from omegaconf import DictConfig, OmegaConf + +from ..callbacks.metrics import accuracy_score +from 
..novelty_prediction.id_vs_ood_entropy_clf import compute_entropies +from ..novelty_prediction.id_vs_ood_nld_clf import compute_nlds +from ..processing.augmentation import DataAugmenter +from ..processing.seq_tokenizer import SeqTokenizer +from ..processing.splitter import * +from ..processing.splitter import DataSplitter +from ..score.score import (compute_score_benchmark, compute_score_tcga, + infere_additional_test_data) +from ..utils.file import load, save +from ..utils.utils import (instantiate_predictor, prepare_data_benchmark, + set_seed_and_device, sync_skorch_with_config) + +logger = logging.getLogger(__name__) + +def compute_cv(cfg:DictConfig,path:str,output_dir:str): + + summary_pd = pd.DataFrame(index=np.arange(cfg["num_replicates"]),columns = ['B. Acc','Dur']) + for seed_no in range(cfg["num_replicates"]): + logger.info(f"Currently training replicate {seed_no}") + cfg["seed"] = seed_no + test_score,net = train(cfg,path=path,output_dir=output_dir) + convrg_epoch = np.where(net.history[:,'val_acc_best'])[0][-1] + convrg_dur = sum(net.history[:,'dur'][:convrg_epoch+1]) + summary_pd.iloc[seed_no] = [test_score,convrg_dur] + + save(path=path+'/summary_pd',data=summary_pd) + + return + +def train(cfg:Dict= None,path:str = None,output_dir:str = None): + if cfg['tensorboard']: + from ..callbacks.tbWriter import writer + #set seed + set_seed_and_device(cfg["seed"],cfg["device_number"]) + + dataset = load(cfg["train_config"].dataset_path_train) + + if isinstance(dataset,AnnData): + dataset = dataset.var + else: + dataset.set_index('sequence',inplace=True) + + #instantiate dataset class + + if cfg["task"] in ["premirna","sncrna"]: + tokenizer = SeqTokenizer(dataset,cfg) + test_ad = load(cfg["train_config"].dataset_path_test) + #prepare data for training and inference + all_data = prepare_data_benchmark(tokenizer,test_ad,cfg) + else: + df = DataAugmenter(dataset,cfg).get_augmented_df() + tokenizer = SeqTokenizer(df,cfg) + all_data = DataSplitter(tokenizer,cfg).prepare_data_tcga() + + #sync skorch config with params in train and model config + sync_skorch_with_config(cfg["model"]["skorch_model"],cfg) + + # instantiate skorch model + net = instantiate_predictor(cfg["model"]["skorch_model"], cfg,path) + + #train + #if train_split is none, then discard valid_ds + net.fit(all_data["train_data"],all_data["train_labels_numeric"],all_data["valid_ds"]) + + #log train and model HP to curr run dir + save(data=OmegaConf.to_container(cfg, resolve=True),path=path+'/meta/hp_settings.yaml') + + #compute scores and log embedds + if cfg['task'] == 'tcga': + test_score = compute_score_tcga(net, all_data,path,cfg) + compute_nlds(output_dir) + compute_entropies(output_dir) + else: + test_score = compute_score_benchmark(net, path,all_data,accuracy_score,cfg) + #only for premirna + if "additional_testset" in all_data: + infere_additional_test_data(net,all_data["additional_testset"]) + + + + if cfg['tensorboard']: + writer.close() + return test_score,net diff --git a/transforna/src/utils/__init__.py b/transforna/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03f6c96852cd3db2107690e642d442fc187bfa48 --- /dev/null +++ b/transforna/src/utils/__init__.py @@ -0,0 +1,4 @@ +from .energy import * +from .file import * +from .tcga_post_analysis_utils import * +from .utils import * diff --git a/transforna/src/utils/energy.py b/transforna/src/utils/energy.py new file mode 100644 index 0000000000000000000000000000000000000000..48ead79e197f80af0521a13099907a99f126287b --- /dev/null 
+++ b/transforna/src/utils/energy.py @@ -0,0 +1,58 @@ + +import functools +import typing as ty + +import pandas as pd +import RNA + + +@functools.lru_cache() +def duplex_energy(s1: str, s2: str) -> float: + return RNA.duplexfold(s1, s2).energy + + +@functools.lru_cache() +def folded_sequence(sequence, model_details): + folder = RNA.fold_compound(sequence, model_details) + dot_bracket, mfe = folder.mfe() + return dot_bracket, mfe + + +def fold_sequences( + sequences: ty.Iterable[str], temperature: float = 37.0, +) -> pd.DataFrame: + + md = RNA.md() + md.temperature = temperature + + seq2structure_map = { + "sequence": [], + f"structure_{int(temperature)}": [], + f"mfe_{int(temperature)}": [], + } + + for sequence in sequences: + dot_bracket, mfe = folded_sequence(sequence, md) + seq2structure_map["sequence"].append(sequence) + seq2structure_map[f"structure_{int(temperature)}"].append(dot_bracket) + seq2structure_map[f"mfe_{int(temperature)}"].append(mfe) + + return pd.DataFrame(seq2structure_map).set_index("sequence") + +def fraction(seq: str, nucleoids: str) -> float: + """Computes the fraction of the sequence string that is the set of nucleoids + given. + + Parameters + ---------- + seq : str + The sequence string + nucleoids : str + The list of nucleoids to compute the fraction for. + + Returns + ------- + float + The fraction + """ + return sum([seq.count(n) for n in nucleoids]) / len(seq) \ No newline at end of file diff --git a/transforna/src/utils/file.py b/transforna/src/utils/file.py new file mode 100644 index 0000000000000000000000000000000000000000..ade833d12fb86a621cfbfc93e0fd26f9620a8416 --- /dev/null +++ b/transforna/src/utils/file.py @@ -0,0 +1,339 @@ +import json +import logging +import os +import pickle +from pathlib import Path +from typing import Any, List + +import anndata +import dill +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import yaml +from anndata import AnnData +from Bio.SeqIO.FastaIO import SimpleFastaParser + +logger = logging.getLogger(__name__) + + +def create_dirs(paths:List): + for path in paths: + if not os.path.exists(path): + os.mkdir(path) + +def save(path: Path, data: object, ignore_ext: bool = False) -> Path: + """Saves data to this path. Extension and saving function is determined from the type. + If the correct extension was already in the path its also ok. + At the moment we handle: + - pyplot figures -> .pdf + - dictionaries -> .yaml + - list -> .yaml + - numpy -> .npy + - pandas dataframes -> .tsv + - anndata -> .h5ad + - strings -> .txt + - _anything else_ -> .p (pickled with `dill`) + Parameters + ---------- + path : Path + The full path to save to + data: object + Data to save + ignore_ext : bool + Whether to ignore adding the normal expected extension + Returns + ------- + Path + The final path to the file + """ + if not isinstance(path, Path): + path = Path(path) + + # Make sure the folder exists: + path.parent.mkdir(parents=True, exist_ok=True) + + annotation_path = os.path.dirname(os.path.abspath(__file__)) + with open(annotation_path+"/tcga_anndata_groupings.yaml", 'r') as stream: + tcga_annotations = yaml.safe_load(stream) + + def make_path(p: Path, ext: str) -> Path: + """If the path doesn't end with the given extension add the extension to the path. 
+ Parameters + ---------- + p : Path + The path + ext : str + The expected extension + Returns + ------- + Path + The fixed path + """ + if not ignore_ext and not p.name.endswith(ext): + return p.parent.joinpath(f"{p.name}{ext}") + return p + + + # PyPlot Figure + if isinstance(data, mpl.figure.Figure): + path = make_path(path, ".pdf") + data.savefig(path) + plt.close(data) + # Dict ⇒ YAML Files + elif isinstance(data, dict): + path = make_path(path, ".yaml") + with open(path, "w") as fp: + yaml.dump(data, fp) + # List ⇒ YAML Files + elif isinstance(data, list): + path = make_path(path, ".yaml") + with open(path, "w") as fp: + yaml.dump(data, fp) + # NumPy Array + elif isinstance(data, np.ndarray): + path = make_path(path, ".npy") + np.save(path, data) + # Dataframes ⇒ TSV + elif isinstance(data, pd.DataFrame): + path = make_path(path, ".tsv") + data.to_csv(path, sep="\t") + # AnnData + elif isinstance(data, anndata.AnnData): + path = make_path(path, ".h5ad") + for date_col in set(tcga_annotations['anndata']['obs']['datetime_columns']) & set(data.obs.columns): + if "datetime" in data.obs[date_col].dtype.name: + data.obs[date_col] = data.obs[date_col].dt.strftime("%Y-%m-%d") + else: + logger.info(f"Column {date_col} in obs should be a date but isnt formatted as one.") + data.write(path) + # Strings to normal files + elif isinstance(data, str): + path = make_path(path, ".txt") + with open(path, "w") as fp: + fp.write(data) + # Everything else ⇒ pickle + else: + path = make_path(path, ".p") + dill.dump(data, open(path, "wb")) + return path + + + +def _resolve_path(path: Path) -> Path: + """Given a path, will try to resolve it in multiple ways: + + 1. Is it a path to a S3 bucket? + 2. Is it a global/local file that exists? + 3. Is it path that is a prefix to a file that is unique? + + Parameters + ---------- + path : Path + The path + + Returns + ------- + Path + The global resolved file. + + Raises + ------ + FileNotFoundError + If the file doesn't exists or if there are multiple files that match the glob. + """ + if not path.name.startswith("/"): + path = path.expanduser().resolve() + + # If it exists we'll take it: + if path.exists(): + return path + + # But mostly we load files without the extension so we glob for a uniue file: + glob_name = path.name if path.name.endswith("*") else path.name + "*" + paths = list(path.parent.glob(glob_name)) + if len(paths) == 1: + return paths[0] # was unique glob + + raise FileNotFoundError( + f"Was trying to resolve path\n\t{path}*\nbut was ambigious because there are no or multiple files that fit the glob." + ) + +def _to_int_string(element: Any) -> str: + """Casts a number to a fixed formatted string that's nice categoriazebale. + + Parameters + ---------- + element : Any + The number, float or int + + Returns + ------- + str + Either the number formatted as a string or the original input if it + didn't work + """ + try: + fl = float(element) + return f"{fl:0.0f}" + except: + return element + +def cast_anndata(ad: AnnData) -> None: + """Fixes the data-type in the `.obs` and `.var` DataFrame columns of an + AnnData object. __Works in-place__. Currently does the following: + + 1.1. Enforces numerical-categorical `.obs` columns + 1.2. Makes all other `.obs` columns categoricals + 1.3. Makes date-time `.obs` columns, non-categorical pandas `datetime64` + 1.4. Enforces real strinng `.obs` columns, to be strings not categoricals + 1.5. 
Enforces some numerical `.obs` columns + + Configuration for which column belongs in which group is configured in + `/transforna/utils/ngs_annotations.yaml` in this repository. + + Parameters + ---------- + ad : AnnData + The AnnData object + """ + # 1. Fix obs-annotation dtypes + + # 1.1. Force numerical looking columns to be actual categorical variables + annotation_path = os.path.dirname(os.path.abspath(__file__)) + with open(annotation_path+"/tcga_anndata_groupings.yaml", 'r') as stream: + tcga_annotations = yaml.safe_load(stream) + numerical_categorical_columns: List[str] = set(tcga_annotations['anndata']['obs']['numerical_categorical_columns']) & set( + ad.obs.columns + ) + for column in numerical_categorical_columns: + ad.obs[column] = ad.obs[column].apply(_to_int_string).astype("U").astype("category") + + # 1.2. Forces string and mixed columns to be categoricals + ad.strings_to_categoricals() + + # 1.3. DateTime, parse dates from string + datetime_columns: List[str] = set(tcga_annotations['anndata']['obs']['datetime_columns']) & set(ad.obs.columns) + for column in datetime_columns: + try: + ad.obs[column] = pd.to_datetime(ad.obs[column]).astype("datetime64[ns]") + except ValueError as e: + warning( + f"""to_datetime error (parsing "unparseable"):\n {e}\nColumn + {column} will be set as string not as datetime.""" + ) + ad.obs[column] = ad.obs[column].astype("string") + + # 1.4. Make _real_ string columns to force to be string, reversing step 1.2. + # These are columns that contain acutal text, something like an description + # or also IDs, which are identical not categories. + string_columns: List[str] = set(tcga_annotations['anndata']['obs']['string_columns']) & set(ad.obs.columns) + for column in string_columns: + ad.obs[column] = ad.obs[column].astype("string") + + # 1.5. Force numerical columns to be numerical, this is necesary with some + # invalid inputs or NaNs + numerical_columns: List[str] = set(tcga_annotations['anndata']['obs']['numerical_columns']) & set(ad.obs.columns) + for column in numerical_columns: + ad.obs[column] = pd.to_numeric(ad.obs[column], errors="coerce") + + # 2. Fix var-annotation dtypes + + # 2.1. Enforce boolean columns to be real python bools, normally NaNs become + # True here, which we change to False. + boolean_columns: List[str] = set(tcga_annotations['anndata']['var']['boolean_columns']) & set(ad.var.columns) + for column in boolean_columns: + ad.var[column].fillna(False, inplace=True) + ad.var[column] = ad.var[column].astype(bool) + + +def load(path: str, ext: str = None, **kwargs): + """Loads the given filepath. + + This will use the extension of the filename to determine what to use for + reading (if not overwritten). Most common use-case: + + At the moment we handle: + + - pickled objects (.p) + - numpy objects (.npy) + - dataframes (.csv, .tsv) + - json files (.json) + - yaml files (.yaml) + - anndata files (.h5ad) + - excel files (.xlsx) + - text (.txt) + + Parameters + ---------- + path : str + The file-name of the cached file, without extension. (Or path) + The file-name can be a glob match e.g. `/data/something/LC__*__21.7.2.*` + which matches the everything with anything filling the stars. This only + works if there is only one match. So this is shortcut if you do not know + the full name but you know there is only one. + ext : str, optional + The extension to assume, ignoring the actual extension. E.g. loading + "tsv" for a "something.csv" file with tab-limits, by default None + + Returns + ------- + Whatever is in the saved file. 
+ + Raises + ------ + FileNotFoundError + If a given path doesn't exist or doesn't give a unqiue file path. + NotImplementedError + Trying to load a file with an extension we do not have loading code for. + """ + path = _resolve_path(Path(path)) + + # If extension is not overwritten take the one from the path_ + if ext is None: + ext = path.suffix[1:] + + # Pickle files + if ext == "p": + return pickle.load(open(path, "rb")) + # Numpy Arrays + elif ext == "npy": + return np.load(path) + # TSV ⇒ DataFrame + elif ext == "tsv": + return pd.read_csv(path, sep="\t", **kwargs) + # CSV ⇒ DataFrame + elif ext == "csv": + return pd.read_csv(path, **kwargs) + # JSON ⇒ dict + elif ext == "json": + return json.load(open(path)) + # YAML ⇒ dict + elif ext == "yaml": + return yaml.load(open(path), Loader=yaml.SafeLoader) + # AnnData + elif ext == "h5ad": + ad = anndata.read_h5ad(path) + cast_anndata(ad) + return ad + # Excel files ⇒ DataFrame + elif ext == "xlsx": + return pd.read_excel(path, **kwargs) + # General text files ⇒ string + elif ext == "txt": + with open(path, "r") as text_file: + return text_file.read() + #fasta + elif ext == "fa": + ## load sequences + with open(path) as fasta_file: + identifiers = [] + seqs = [] + for title, sequence in SimpleFastaParser(fasta_file): + identifiers.append(title.split(None, 1)[0]) + seqs.append(sequence) + #convert sequences to dataframe + return pd.DataFrame({'Sequences':seqs}) + else: + raise NotImplementedError diff --git a/transforna/src/utils/tcga_anndata_groupings.yaml b/transforna/src/utils/tcga_anndata_groupings.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7d369e167b29b065e6f01507d3edd4820f1ad3c --- /dev/null +++ b/transforna/src/utils/tcga_anndata_groupings.yaml @@ -0,0 +1,16 @@ +anndata: + obs: + numerical_columns: + - length + # These are columns that will be parsed with pd.to_datetime + datetime_columns: + - None + # These are columns which only contain strings but should not be set as categorical! 
+ string_columns: + - snRNA_name + # These are the columns which contain numerical values but should be treated as categorical variables + numerical_categorical_columns: + - None + var: + boolean_columns: + - spikein \ No newline at end of file diff --git a/transforna/src/utils/tcga_post_analysis_utils.py b/transforna/src/utils/tcga_post_analysis_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b43ae72b4aef0cbda2ecbddc76e9fbf88659d8 --- /dev/null +++ b/transforna/src/utils/tcga_post_analysis_utils.py @@ -0,0 +1,333 @@ +import logging +import os +import pickle +from pathlib import Path +from typing import Dict, List + +import numpy as np +import pandas as pd +import scanpy as sc +from anndata import AnnData +from sklearn.neighbors import NearestNeighbors + +from .file import create_dirs, load + +logger = logging.getLogger(__name__) +class Results_Handler(): + def __init__(self,embedds_path:str,splits:List,mc_flag:bool=False,read_dataset:bool=False,create_knn_graph:bool=False,run_name:str=None,save_results:bool=False) -> None: + self.save_results = save_results + self.all_splits = ['train','valid','test','ood','artificial','no_annotation'] + if splits == ['all']: + self.splits = self.all_splits + else: + self.splits = splits + self.mc_flag = mc_flag + + #if embedds is not at the end of the embedds_path then append it + if not embedds_path.endswith('embedds'): + embedds_path = embedds_path+'/embedds' + + _,self.splits_df_dict = self.get_data(embedds_path,self.splits) + + #set column names + self.embedds_cols:List = [col for col in self.splits_df_dict[f'{splits[0]}_df'] if "Embedds" in col[0]] + self.seq_col:str = 'RNA Sequences' + self.label_col:str = 'Labels' + + #create directories + self.parent_path:str = '/'.join(embedds_path.split('/')[:-1]) + self.figures_path:str = self.parent_path+'/figures' + self.analysis_path:str = self.parent_path+'/analysis' + self.meta_path:str = self.parent_path+'/meta' + self.umaps_path:str = self.parent_path+'/umaps' + self.post_models_path:str = self.parent_path+'/post_models' + create_dirs([self.figures_path,self.analysis_path,self.post_models_path]) + + #get half of embedds cols if the model is Seq + model_name = self.get_hp_param(hp_param="model_name") + if model_name == 'seq': + self.embedds_cols = self.embedds_cols[:len(self.embedds_cols)//2] + + if not run_name: + self.run_name = self.get_hp_param(hp_param="model_input") + if type(self.run_name) == list: + self.run_name = '-'.join(self.run_name) + + ad_path = self.get_hp_param(hp_param="dataset_path_train") + if read_dataset: + self.dataset = load(ad_path) + if isinstance(self.dataset,AnnData): + self.dataset = self.dataset.var + + + self.seperate_label_from_split(split='artificial',removed_label='artificial_affix') + self.seperate_label_from_split(split='artificial',removed_label='random') + self.seperate_label_from_split(split='artificial',removed_label='recombined') + + self.sc_to_mc_mapper_dict = self.load_mc_mapping_dict() + + #get whether curr results are trained on ID or FULL + self.trained_on = self.get_hp_param(hp_param="trained_on") + #the main config of models trained on ID is not logged as for FULL + if self.trained_on == None: + self.trained_on = 'id' + + + #read train to be used for knn training and inference + train_df = self.splits_df_dict['train_df'] + + self.knn_seqs = train_df[self.seq_col].values + self.knn_labels = train_df[self.label_col].values + + #create knn model if does not exist + if create_knn_graph: + self.create_knn_model() + + def 
create_knn_model(self): + #get all train embedds + train_embedds = self.splits_df_dict['train_df'][self.embedds_cols].values + #linalg + train_embedds = train_embedds/np.linalg.norm(train_embedds,axis=1)[:,None] + #create knn model + self.knn_model = NearestNeighbors(n_neighbors=10,algorithm='brute',n_jobs=-1) + self.knn_model.fit(train_embedds) + #save knn model + filename = self.post_models_path+'/knn_model.sav' + pickle.dump(self.knn_model,open(filename,'wb')) + return + + def get_knn_model(self): + filename = self.post_models_path+'/knn_model.sav' + self.knn_model = pickle.load(open(filename,'rb')) + return + + def seperate_label_from_split(self,split,removed_label:str='artificial_affix'): + + if split in self.splits: + logger.debug(f"splitting {removed_label} from split: {split}") + + + #get art affx + removed_label_df = self.splits_df_dict[f"{split}_df"].loc[self.splits_df_dict[f"{split}_df"][self.label_col]['0'] == removed_label] + + #append art affx as key + self.splits_df_dict[f'{removed_label}_df'] = removed_label_df + #remove art affx from ood + removed_label_ids = self.splits_df_dict[f"{split}_df"].index.isin(removed_label_df.index) + self.splits_df_dict[f"{split}_df"] = self.splits_df_dict[f"{split}_df"][~removed_label_ids].reset_index(drop=True) + + #resetf {split}_affix_idx + self.splits_df_dict[f'{removed_label}_df'] = self.splits_df_dict[f'{removed_label}_df'].reset_index(drop=True) + self.all_splits.append(f'{removed_label}') + + + def append_loco_variants(self): + train_classes = self.splits_df_dict["train_df"]["Logits"].columns.values + if self.mc_flag: + all_loco_classes_df = self.dataset['small_RNA_class_annotation'][self.dataset['small_RNA_class_annotation_hico'].isnull()].str.split(';', expand=True) + else: + all_loco_classes_df = self.dataset['subclass_name'][self.dataset['hico'].isnull()].str.split(';', expand=True) + + all_loco_classes = all_loco_classes_df.values + + #TODO: optimize getting unique values + loco_classes = [] + for col in all_loco_classes_df.columns: + loco_classes.extend(all_loco_classes_df[col].unique()) + + loco_classes = list(set(loco_classes)) + if np.nan in loco_classes: + loco_classes.remove(np.nan) + if None in loco_classes: + loco_classes.remove(None) + + #compute loco not in train mask + loco_classes_not_in_train = list(set(loco_classes).difference(set(train_classes))) + loco_mask_not_in_train_df = all_loco_classes_df.isin(loco_classes_not_in_train) + + + mixed_and_not_in_train_df = all_loco_classes_df.iloc[loco_mask_not_in_train_df.values.sum(axis=1) >= 1] + train_classes_mask = mixed_and_not_in_train_df.isin(train_classes) + + loco_not_in_train_df = mixed_and_not_in_train_df[train_classes_mask.values.sum(axis=1) == 0] + loco_mixed_df = mixed_and_not_in_train_df[~(train_classes_mask.values.sum(axis=1) == 0)] + + nans_and_loco_train_df = all_loco_classes_df.iloc[loco_mask_not_in_train_df.values.sum(axis=1) == 0] + nans_mask = nans_and_loco_train_df.isin([None,np.nan]) + nanas_df = nans_and_loco_train_df[nans_mask.values.sum(axis=1) == len(nans_mask.columns)] + loco_in_train_df = nans_and_loco_train_df[nans_mask.values.sum(axis=1) < len(nans_mask.columns)] + + self.splits_df_dict["loco_not_in_train_df"] = self.splits_df_dict["no_annotation_df"][self.splits_df_dict["no_annotation_df"][self.seq_col]['0'].isin(loco_not_in_train_df.index)] + self.splits_df_dict["loco_mixed_df"] = self.splits_df_dict["no_annotation_df"][self.splits_df_dict["no_annotation_df"][self.seq_col]['0'].isin(loco_mixed_df.index)] + 
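+        #sequences whose annotations are all covered by the training classes go to loco_in_train,
+        #while entries without any annotation at all stay in no_annotation
+        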
self.splits_df_dict["loco_in_train_df"] = self.splits_df_dict["no_annotation_df"][self.splits_df_dict["no_annotation_df"][self.seq_col]['0'].isin(loco_in_train_df.index)] + self.splits_df_dict["no_annotation_df"] = self.splits_df_dict["no_annotation_df"][self.splits_df_dict["no_annotation_df"][self.seq_col]['0'].isin(nanas_df.index)] + + def get_data(self,path:str,splits:List,ith_run:int = -1): + #results exist in the outputs folder. + #outputs folder has two depth levels, first level indicates day and second indicates time per day + #if path not given, get results from last run + #ith run specifies the last run (-1), second last(-2)... etc + if not path: + files = os.path.abspath(os.path.join(os.path.dirname( __file__ ), '../../..', 'outputs')) + logging.debug(files) + #newest + paths = sorted(list(Path(files).rglob('')), key=lambda x: Path.stat(x).st_mtime, reverse=True) + ith_run = abs(ith_run) + for path in paths: + if str(path).endswith('embedds'): + ith_run-= 1 + if ith_run == 0: + path = str(path) + break + + split_dfs = {} + splits_to_remove = [] + for split in splits: + try: + #read logits csv + split_df = load( + path+f'/{split}_embedds.tsv', + header=[0, 1], + index_col=0, + ) + split_df['split','0'] = split + split_dfs[f"{split}_df"] = split_df + + except: + splits_to_remove.append(split) + logger.info(f'{split} does not exist in embedds! Removing it from splits') + + for split in splits_to_remove: + self.splits.remove(split) + + + return path,split_dfs + + def get_hp_param(self,hp_param): + hp_settings = load(path=self.meta_path+'/hp_settings.yaml') + #hp_param could be in hp_settings .keyes or in a key of a key + hp_val = hp_settings.get(hp_param) + if not hp_val: + for key in hp_settings.keys(): + try: + hp_val = hp_settings[key].get(hp_param) + except: + pass + if hp_val != None: + break + if hp_val == None: + raise ValueError(f"hp_param {hp_param} not found in hp_settings") + + return hp_val + + def load_mc_mapping_dict(self): + mapping_dict_path = self.get_hp_param(hp_param="mapping_dict_path") + + return load(mapping_dict_path) + + def compute_umap(self, + ad, + nn=50, + spread=10, + min_dist=1.0, + ): + sc.tl.pca(ad) + sc.pp.neighbors(ad, n_neighbors=nn, n_pcs=None, use_rep="X_pca") + sc.tl.umap(ad, n_components=2, spread=spread, min_dist=min_dist) + logger.info(f'cords are: {ad.obsm}') + return ad + + + def plot_umap(self,ad, + ncols=3, + colors=['Labels',"Unseen Labels"], + edges=False, + edges_width=0.05, + run_name = None, + path=None, + task=None + ): + sc.set_figure_params(dpi = 80,figsize=[10,10]) + fig = sc.pl.umap( + ad, + ncols=ncols, + color=colors, + edges=edges, + edges_width=edges_width, + title=[f"{run_name} approach: {c} {ad.shape}" for c in colors], + size = ad.obs["size"], + return_fig=True, + save=False + ) + + #fig.savefig(f'{path}{run_name}_{task}_umap.png') + def merge_all_splits(self): + all_dfs = [self.splits_df_dict[f'{split}_df'] for split in self.all_splits] + self.splits_df_dict['all_df'] = pd.concat(all_dfs).reset_index(drop=True) + return + + +def correct_labels(predicted_labels:pd.DataFrame,actual_labels:pd.DataFrame,mapping_dict:Dict): + ''' + This function corrects the predicted labelsfor the bin based sub classes, tRNAs and miRNAs. + First both the actual and predicted labels are converted to major class. There are three classes of major classes: + 1. tRNA: if the actual and predicted agree of all the tRNA sub class name except for the last part after -, then the predicted label is corrected to the actual label + 2. 
bin based sub classes: if the actual and the predicted agree on sub class and the bin number is within 1 of the actual bin number, then the predicted label is corrected to the actual label + the bin number is after the last - + 3. miRNAs: if the predicted and the actual agree on the first and last part of the subclass, and agree with the numeric part of the middle part, then the predicted label is corrected to the actual label + ''' + if type(predicted_labels) == pd.Series: + predicted_labels = predicted_labels.values + actual_labels = actual_labels.values + import re + corrected_labels = [] + for i in range(len(predicted_labels)): + predicted_label = predicted_labels[i] + actual_label = actual_labels[i] + if predicted_label == actual_label: + corrected_labels.append(predicted_label) + else: + mc = mapping_dict[actual_label] + if mc == 'tRNA': + if mapping_dict[predicted_label] == 'tRNA': + predicted_prec = predicted_label.split('__')[1] + actual_prec = actual_label.split('__')[1] + + #the precursor is split by a -, if both have the share the same first part, then correct + if predicted_prec == actual_prec: + corrected_labels.append(actual_label) + else: + corrected_labels.append(predicted_label) + else: + corrected_labels.append(predicted_label) + elif mc == 'miRNA' and ('mir' in predicted_label.lower() or 'let' in predicted_label.lower()): + #check that the both share the same prime end (either 3p or 5p) + if predicted_label.split('-')[-1] == actual_label.split('-')[-1]: + #check that the both share the same numeric part + predicted_numeric = re.findall(r'\d+', predicted_label.split('-')[1])[0] + actual_numeric = re.findall(r'\d+', actual_label.split('-')[1])[0] + if predicted_numeric == actual_numeric: + corrected_labels.append(actual_label) + else: + corrected_labels.append(predicted_label) + else: + corrected_labels.append(predicted_label) + + elif 'bin' in actual_label: + if '__' in predicted_label and '__' in actual_label: + predicted_label = predicted_label.split('__')[1] + actual_label = actual_label.split('__')[1] + if 'bin' in predicted_label and predicted_label.split('-')[0] == actual_label.split('-')[0]: + #get the bin number + actual_bin = int(actual_label.split('-')[-1]) + predicted_bin = int(predicted_label.split('-')[-1]) + #check that the predicted bin is within 1 of the actual bin + if abs(actual_bin - predicted_bin) <= 1: + corrected_labels.append(actual_label) + else: + corrected_labels.append(predicted_label) + else: + corrected_labels.append(predicted_label) + else: + corrected_labels.append(predicted_label) + return corrected_labels diff --git a/transforna/src/utils/utils.py b/transforna/src/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..42d208c945e70be589cca78db2000de1bdfe3617 --- /dev/null +++ b/transforna/src/utils/utils.py @@ -0,0 +1,577 @@ + +import logging +import math +import os +import random +from pathlib import Path +from random import randint + +import numpy as np +import pandas as pd +import torch +from hydra._internal.utils import _locate +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf +from scipy.stats import entropy +from sklearn.model_selection import train_test_split +from sklearn.utils.class_weight import (compute_class_weight, + compute_sample_weight) +from skorch.dataset import Dataset +from skorch.helper import predefined_split + +from ..callbacks.metrics import get_callbacks +from ..score.score import infer_from_model +from .energy import * +from .file import load + +logger 
= logging.getLogger(__name__) + +def update_config_with_inference_params(config:DictConfig,mc_or_sc:str='sub_class',trained_on:str = 'id',path_to_models:str = 'models/tcga/') -> DictConfig: + inference_config = config.copy() + model = config['model_name'] + model = "-".join([word.capitalize() for word in model.split("-")]) + transforna_folder = "TransfoRNA_ID" + if trained_on == "full": + transforna_folder = "TransfoRNA_FULL" + + inference_config['inference_settings']["model_path"] = f'{path_to_models}{transforna_folder}/{mc_or_sc}/{model}/ckpt/model_params_tcga.pt' + inference_config["inference"] = True + inference_config["log_logits"] = False + + + inference_config = DictConfig(inference_config) + #train and model config should be fetched from teh inference model + train_cfg_path = get_hp_setting(inference_config, "train_config") + model_cfg_path = get_hp_setting(inference_config, "model_config") + train_config = instantiate(train_cfg_path) + model_config = instantiate(model_cfg_path) + # prepare configs as structured dicts + train_config = OmegaConf.structured(train_config) + model_config = OmegaConf.structured(model_config) + # update model config with the name of the model + model_config["model_input"] = inference_config["model_name"] + inference_config = OmegaConf.merge({"train_config": train_config, "model_config": model_config}, inference_config) + return inference_config + +def update_config_with_dataset_params_benchmark(train_data_df,configs): + ''' + After tokenizing the dataset, some features in the config needs to be updated as they will be used + later by sub modules + ''' + # set feedforward input dimension and vocab size + #ss_tokens_id and tokens_id are the same + configs["model_config"].second_input_token_len = train_data_df["second_input"].shape[1] + configs["model_config"].tokens_len = train_data_df["tokens_id"].shape[1] + #set batch per epoch (number of batches). 
This will be used later by both the criterion and the LR scheduler
+    configs["train_config"].batch_per_epoch = train_data_df["tokens_id"].shape[0]/configs["train_config"].batch_size
+    return
+
+def update_config_with_dataset_params_tcga(dataset_class,all_data_df,configs):
+    configs["model_config"].ff_input_dim = all_data_df['second_input'].shape[1]
+    configs["model_config"].vocab_size = len(dataset_class.seq_tokens_ids_dict.keys())
+    configs["model_config"].second_input_vocab_size = len(dataset_class.second_input_tokens_ids_dict.keys())
+    configs["model_config"].tokens_len = dataset_class.tokens_len
+    configs["model_config"].second_input_token_len = dataset_class.tokens_len
+
+    if configs["model_name"] == "seq-seq":
+        configs["model_config"].tokens_len = math.ceil(dataset_class.tokens_len/2)
+        configs["model_config"].second_input_token_len = math.ceil(dataset_class.tokens_len/2)
+
+
+def update_dataclass_inference(cfg,dataset_class):
+    seq_token_dict,ss_token_dict = get_tokenization_dicts(cfg)
+    dataset_class.seq_tokens_ids_dict = seq_token_dict
+    dataset_class.second_input_tokens_ids_dict = ss_token_dict
+    dataset_class.tokens_len = cfg["model_config"].tokens_len
+    dataset_class.max_length = get_hp_setting(cfg,'max_length')
+    dataset_class.min_length = get_hp_setting(cfg,'min_length')
+    return dataset_class
+
+def set_seed_and_device(seed:int = 0,device_no:int=0):
+    # set seed
+    torch.backends.cudnn.deterministic = True
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+    torch.cuda.set_device(device_no)
+    #CUDA_LAUNCH_BLOCKING=1 #for debugging
+
+def sync_skorch_with_config(skorch_cfg: DictConfig,cfg:DictConfig):
+    '''
+    The skorch config contains params that duplicate those in the train and model configs.
+    The values in the skorch config should be populated from the train and
+    model config.
+    '''
+
+    #populate skorch params with params in train or model config if they exist
+    for key in skorch_cfg:
+        if key in cfg["train_config"]:
+            skorch_cfg[key] = cfg["train_config"][key]
+        if key in cfg["model_config"]:
+            skorch_cfg[key] = cfg["model_config"][key]
+
+    return
+
+def instantiate_predictor(skorch_cfg: DictConfig,cfg:DictConfig,path: str=None):
+    # convert config to omegaconf container
+    predictor_config = OmegaConf.to_container(skorch_cfg)
+    # Patch model device argument from the run config:
+    if "device" in predictor_config:
+        predictor_config["device"] = skorch_cfg["device"]
+    for key, val in predictor_config.items():
+        try:
+            predictor_config[key] = _locate(val)
+        except:
+            continue
+    #add callbacks to list of params
+    predictor_config["callbacks"] = get_callbacks(path,cfg)
+
+
+    #save callbacks, as instantiate changes the lr callback from tuple to list,
+    #and then skorch's instantiate_callback throws an error
+    callbacks_list = predictor_config["callbacks"]
+    predictor_config["callbacks"] = "disable"
+
+    #remove model from the cfg, otherwise instantiate will throw an error as
+    #the model's scoring doesn't receive input params
+    predictor_config["module__main_config"] = \
+        {key:cfg[key] for key in cfg if key not in ["model"]}
+    #in case of the tcga task, remove the dataset as it is already instantiated
+    if 'dataset' in predictor_config['module__main_config']:
+        del predictor_config['module__main_config']['dataset']
+
+    #set train split to false in skorch model
+    if not cfg['train_split']:
+        predictor_config['train_split'] = False
+    net = instantiate(predictor_config)
+    #restore callbacks and instantiate them
+    net.callbacks = callbacks_list
+    net.task = cfg['task']
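+    #callbacks were set to "disable" during instantiate, so the saved list is re-attached
+    #above and initialized manually here; initialized_ is then set to avoid re-initialization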
+    net.initialize_callbacks()
+    #prevents double initialization
+    net.initialized_ = True
+    return net
+
+def get_fused_seqs(seqs,num_sequences:int=1,max_len:int=30):
+    '''
+    fuse num_sequences sequences from seqs
+    '''
+    fused_seqs = []
+    while len(fused_seqs) < num_sequences:
+        #get two random sequences
+        seq1 = random.choice(seqs)[:max_len]
+        seq2 = random.choice(seqs)[:max_len]
+
+        #select an index to truncate seq1 at, between 30% and 70% of its length
+        idx = random.randint(math.floor(len(seq1)*0.3),math.floor(len(seq1)*0.7))
+        len_to_be_added_from_seq2 = len(seq1)-idx
+        #truncate seq1 at idx
+        seq1 = seq1[:idx]
+        #get the rest from the beginning of seq2
+        seq2 = seq2[:len_to_be_added_from_seq2]
+        #fuse seq1 and seq2
+        fused_seq = seq1+seq2
+
+        if fused_seq not in fused_seqs and fused_seq not in seqs:
+            fused_seqs.append(fused_seq)
+
+    return fused_seqs
+
+def revert_seq_tokenization(tokenized_seqs,configs):
+    window = configs["model_config"].window
+    if configs["model_config"].tokenizer != "overlap":
+        logger.error("Sequences are not reverse tokenized")
+        return tokenized_seqs
+
+    #currently only the overlap tokenizer can be reverted
+    seqs_concat = []
+    for seq in tokenized_seqs.values:
+        seqs_concat.append(''.join(seq[seq!='pad'])[::window]+seq[seq!='pad'][-1][window-1])
+
+    return pd.DataFrame(seqs_concat,columns=["Sequences"])
+
+def introduce_mismatches(seq, n_mismatches):
+    seq = list(seq)
+    for i in range(n_mismatches):
+        rand_nt = randint(0,len(seq)-1)
+        seq[rand_nt] = ['A','G','C','T'][randint(0,3)]
+    return ''.join(seq)
+
+def prepare_split(split_data_df,configs):
+    '''
+    This function returns tokens, token ids and labels for a given dataframe's split.
+    It also moves tokens and labels to the device.
+    '''
+
+    model_input_cols = ['tokens_id','second_input','seqs_length']
+    #token_ids
+    split_data = torch.tensor(
+        np.array(split_data_df[model_input_cols].values, dtype=float),
+        dtype=torch.float,
+    )
+    split_weights = torch.tensor(compute_sample_weight('balanced',split_data_df['Labels']))
+    split_data = torch.cat([split_data,split_weights[:,None]],dim=1)
+    #tokens (chars)
+    split_rna_seq = revert_seq_tokenization(split_data_df["tokens"],configs)
+
+    #labels
+    split_labels = torch.tensor(
+        np.array(split_data_df["Labels"], dtype=int),
+        dtype=torch.long,
+    )
+    return split_data, split_rna_seq, split_labels
+
+def prepare_model_inference(cfg,path):
+    # instantiate skorch model
+    net = instantiate_predictor(cfg["model"]["skorch_model"], cfg,path)
+    net.initialize()
+
+    logger.info(f"Model loaded from {cfg['inference_settings']['model_path']}")
+    net.load_params(f_params=f'{cfg["inference_settings"]["model_path"]}')
+    net.labels_mapping_dict = dict(zip(cfg["model_config"].class_mappings,list(np.arange(cfg["model_config"].num_classes))))
+    #save embeddings
+    if cfg['log_embedds']:
+        net.save_embedding=True
+    net.gene_embedds = []
+    net.second_input_embedds = []
+    return net
+
+def prepare_data_benchmark(tokenizer,test_ad, configs):
+    """
+    This function receives AnnData and prepares the anndata in a format suitable for training.
+    It also sets default parameters in the config that cannot be known until the preprocessing step
+    is done.
+ all_data_df is heirarchical pandas dataframe, so can be accessed [AA,AT,..,AC ] + """ + ###get tokenized train set + train_data_df = tokenizer.get_tokenized_data() + + ### update config with data specific params + update_config_with_dataset_params_benchmark(train_data_df,configs) + + ###tokenize test set + test_data_df = tokenize_set(tokenizer,test_ad.var) + + ### get tokens(on device), seqs and labels(on device) + train_data, train_rna_seq, train_labels = prepare_split(train_data_df,configs) + test_data, test_rna_seq, test_labels = prepare_split(test_data_df,configs) + + class_weights = compute_class_weight(class_weight='balanced',classes=np.unique(train_labels.flatten()),y=train_labels.flatten().numpy()) + + + #omegaconfig does not support float64 as datatype so conversion to str is done + # and reconversion is done in criterion + configs['model_config'].class_weights = [str(x) for x in list(class_weights)] + + if configs["train_split"]: + #stratify train to get valid + train_data,valid_data,train_labels,valid_labels = stratify(train_data,train_labels,configs["valid_size"]) + valid_ds = Dataset(valid_data,valid_labels) + valid_ds=predefined_split(valid_ds) + else: + valid_ds = None + + all_data= {"train_data":train_data, + "valid_ds":valid_ds, + "test_data":test_data, + "train_rna_seq":train_rna_seq, + "test_rna_seq":test_rna_seq, + "train_labels_numeric":train_labels, + "test_labels_numeric":test_labels} + + if configs["task"] == "premirna": + generalization_test_set = get_add_test_set(tokenizer,\ + dataset_path=configs["train_config"].datset_path_additional_testset) + + + #get all vocab from both test and train set + configs["model_config"].vocab_size = len(tokenizer.seq_tokens_ids_dict.keys()) + configs["model_config"].second_input_vocab_size = len(tokenizer.second_input_tokens_ids_dict.keys()) + configs["model_config"].tokens_mapping_dict = tokenizer.seq_tokens_ids_dict + + + if configs["task"] == "premirna": + generalization_test_data = [] + for test_df in generalization_test_set: + #no need to read the labels as they are all one + test_data_extra, _, _ = prepare_split(test_df,configs) + generalization_test_data.append(test_data_extra) + all_data["additional_testset"] = generalization_test_data + + #get inference dataset + # if do inference and inference datasert path exists + get_inference_data(configs,tokenizer,all_data) + + return all_data + +def prepare_inference_results_benchmarck(net,cfg,predicted_labels,logits,all_data): + iterables = [["Sequences"], np.arange(1, dtype=int)] + index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"]) + rna_seqs_df = pd.DataFrame(columns=index, data=np.vstack(all_data["infere_rna_seq"]["Sequences"].values)) + + iterables = [["Logits"], list(net.labels_mapping_dict.keys())] + index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"]) + logits_df = pd.DataFrame(columns=index, data=np.array(logits)) + + #add Labels,entropy to df + all_data["infere_rna_seq"]["Labels",'0'] = predicted_labels + all_data["infere_rna_seq"].set_index("Sequences",inplace=True) + + #log logits if required + if cfg["log_logits"]: + seq_logits_df = logits_df.join(rna_seqs_df).set_index(("Sequences",0)) + all_data["infere_rna_seq"] = all_data["infere_rna_seq"].join(seq_logits_df) + else: + all_data["infere_rna_seq"].columns = ['Labels'] + + return + +def prepare_inference_results_tcga(cfg,predicted_labels,logits,all_data,max_len): + + logits_clf = load('/'.join(cfg["inference_settings"]["model_path"].split('/')[:-2])\ + 
+def prepare_inference_results_tcga(cfg,predicted_labels,logits,all_data,max_len):
+    logits_clf = load('/'.join(cfg["inference_settings"]["model_path"].split('/')[:-2])
+                      +'/analysis/logits_model_coef.yaml')
+    threshold = round(logits_clf['Threshold'],2)
+
+    iterables = [["Sequences"], np.arange(1, dtype=int)]
+    index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"])
+    rna_seqs_df = pd.DataFrame(columns=index, data=np.vstack(all_data["infere_rna_seq"]["Sequences"].values))
+
+    iterables = [["Logits"], cfg['model_config'].class_mappings]
+    index = pd.MultiIndex.from_product(iterables, names=["type of data", "indices"])
+    logits_df = pd.DataFrame(columns=index, data=np.array(logits))
+
+    #add predicted labels and the novelty ("Is Familiar?") flag to the dataframe
+    all_data["infere_rna_seq"]["Net-Label"] = predicted_labels
+    all_data["infere_rna_seq"]["Is Familiar?"] = entropy(logits,axis=1) <= threshold
+
+    all_data["infere_rna_seq"].set_index("Sequences",inplace=True)
+
+    #log logits if required
+    if cfg["log_logits"]:
+        seq_logits_df = logits_df.join(rna_seqs_df).set_index(("Sequences",0))
+        all_data["infere_rna_seq"] = all_data["infere_rna_seq"].join(seq_logits_df)
+
+    all_data["infere_rna_seq"].index.name = f'Sequences, Max Length={max_len}'
+
+    return
+
+def prepare_inference_data(cfg,infer_pd,dataset_class):
+    #tokenize sequences
+    infere_data_df = tokenize_set(dataset_class,infer_pd,inference=True)
+    infere_data,infere_rna_seq,_ = prepare_split(infere_data_df,cfg)
+
+    all_data = {}
+    all_data["infere_data"] = infere_data
+    all_data["infere_rna_seq"] = infere_rna_seq
+    return all_data
+
+def get_inference_data(configs,dataset_class,all_data):
+    if configs["inference"] and configs["inference_settings"]["sequences_path"] is not None:
+        inference_file = configs["inference_settings"]["sequences_path"]
+        inference_path = Path(__file__).parent.parent.parent.absolute() / f"{inference_file}"
+
+        infer_data = load(inference_path)
+        #compute secondary structure if infer_data does not already have it
+        if "Secondary" not in infer_data:
+            infer_data['Secondary'] = dataset_class.get_secondary_structure(infer_data["Sequences"])
+        if "Labels" not in infer_data:
+            infer_data["Labels"] = [0]*len(infer_data["Sequences"].values)
+
+        dataset_class.seqs_dot_bracket_labels = infer_data
+
+        dataset_class.min_length = 0
+        dataset_class.limit_seqs_to_range(logger)
+        infere_data_df = dataset_class.get_tokenized_data(inference=True)
+        infere_data,infere_rna_seq,_ = prepare_split(infere_data_df,configs)
+
+        all_data["infere_data"] = infere_data
+        all_data["infere_rna_seq"] = infere_rna_seq
+
+def get_add_test_set(dataset_class,dataset_path):
+    all_added_test_set = []
+    #get paths of all files in mirbase and mirgene
+    paths_mirbase = dataset_path+"mirbase/"
+    files_mirbase = os.listdir(paths_mirbase)
+    for f_idx,_ in enumerate(files_mirbase):
+        files_mirbase[f_idx] = paths_mirbase+files_mirbase[f_idx]
+
+    paths_mirgene = dataset_path + "mirgene/"
+    files_mirgene = os.listdir(paths_mirgene)
+    for f_idx,_ in enumerate(files_mirgene):
+        files_mirgene[f_idx] = paths_mirgene+files_mirgene[f_idx]
+    files = files_mirbase+files_mirgene
+    for f in files:
+        #tokenize test set
+        test_pd = load(f)
+        test_pd = test_pd.drop(columns='Unnamed: 0')
+        test_pd["Sequences"] = test_pd["Sequences"].astype(object)
+        test_pd["Secondary"] = test_pd["Secondary"].astype(object)
+        #all sequences in the additional test sets are positives
+        test_pd["Labels"] = 1
+
+        dataset_class.seqs_dot_bracket_labels = test_pd
+        dataset_class.limit_seqs_to_range()
+        all_added_test_set.append(dataset_class.get_tokenized_data())
+    return all_added_test_set
+
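+# Minimal sketch of the "Is Familiar?" rule applied in
+# prepare_inference_results_tcga: a prediction counts as familiar when the
+# entropy of its class scores is at or below the threshold loaded from the
+# logits model. The helper name and the example threshold are illustrative only.
+def _is_familiar_sketch(class_probs, threshold=0.9):
+    # class_probs: (n_samples, n_classes); low entropy = confident = familiar
+    return entropy(class_probs, axis=1) <= threshold
+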
+def get_tokenization_dicts(cfg):
+    tokenization_path = '/'.join(cfg['inference_settings']['model_path'].split('/')[:-2])
+    seq_token_dict = load(tokenization_path+'/seq_tokens_ids_dict')
+    ss_token_dict = load(tokenization_path+'/second_input_tokens_ids_dict')
+    return seq_token_dict,ss_token_dict
+
+def get_hp_setting(cfg,hp_param):
+    model_parent_path = Path('/'.join(cfg['inference_settings']['model_path'].split('/')[:-2]))
+    hp_settings = load(model_parent_path/'meta/hp_settings.yaml')
+
+    #hp_param can be a top-level key of hp_settings or nested one level deep
+    hp_val = hp_settings.get(hp_param)
+    if not hp_val:
+        for key in hp_settings.keys():
+            try:
+                hp_val = hp_settings[key].get(hp_param)
+            except AttributeError:
+                #the value under this key is not a mapping
+                pass
+            if hp_val is not None:
+                break
+    if hp_val is None:
+        raise ValueError(f"hp_param {hp_param} not found in hp_settings")
+
+    return hp_val
+
+def get_model(cfg,path):
+    cfg["model_config"] = get_hp_setting(cfg,'model_config')
+
+    sync_skorch_with_config(cfg["model"]["skorch_model"],cfg)
+    cfg['model_config']['model_input'] = cfg['model_name']
+    net = prepare_model_inference(cfg,path)
+    return cfg,net
+
+def stratify(train_data,train_labels,valid_size):
+    return train_test_split(train_data, train_labels,
+                            stratify=train_labels,
+                            test_size=valid_size)
+
+def tokenize_set(dataset_class,test_pd,inference:bool=False):
+    #reassign the sequences to the test split
+    dataset_class.seqs_dot_bracket_labels = test_pd
+    #prevent sequences with len < min length from being deleted
+    dataset_class.limit_seqs_to_range()
+    return dataset_class.get_tokenized_data(inference)
+
+def add_original_seqs_to_predictions(short_to_long_df,pred_df):
+    short_to_long_df.set_index('Sequences',inplace=True)
+    pred_df = pd.merge(pred_df,short_to_long_df[['Trimmed','Original_Sequence']],right_index=True,left_index=True,how='left')
+    #drop repeated indices, keeping the first occurrence
+    pred_df = pred_df[~pred_df.index.duplicated(keep='first')]
+    return pred_df
+
+def add_ss_and_labels(infer_data):
+    #compute secondary structure if infer_data does not already have it
+    if "Secondary" not in infer_data:
+        infer_data["Secondary"] = fold_sequences(infer_data["Sequences"].tolist())['structure_37'].values
+    if "Labels" not in infer_data:
+        infer_data["Labels"] = [0]*len(infer_data["Sequences"].values)
+    return infer_data
+
+def chunkstring_overlap(string, window):
+    return (
+        string[0 + i : window + i] for i in range(0, len(string) - window + 1, 1)
+    )
+
+def create_short_seqs_from_long(df,max_len):
+    long_seqs = df['Sequences'][df['Sequences'].str.len()>max_len].values
+    short_seqs_pd = df[df['Sequences'].str.len()<=max_len]
+    feature_tokens_gen = list(
+        chunkstring_overlap(feature, max_len)
+        for feature in long_seqs
+    )
+    original_seqs = []
+    shortened_seqs = []
+    for i,feature_tokens in enumerate(feature_tokens_gen):
+        curr_trimmed_seqs = list(feature_tokens)
+        shortened_seqs.extend(curr_trimmed_seqs)
+        original_seqs.extend([long_seqs[i]]*len(curr_trimmed_seqs))
+    short_to_long_dict = dict(zip(shortened_seqs,original_seqs))
+    shortened_df = pd.DataFrame(data=shortened_seqs,columns=['Sequences'])
+    #DataFrame.append is removed in pandas 2.x; concat keeps the shortened rows
+    #first, so shortened_df.index still addresses them after reset_index
+    df = pd.concat([shortened_df,short_seqs_pd]).reset_index(drop=True)
+    #flag trimmed sequences and keep a reference to the original sequence
+    df['Trimmed'] = False
+    df.loc[shortened_df.index,'Trimmed'] = True
+    df['Original_Sequence'] = df['Sequences']
+    df.loc[shortened_df.index,'Original_Sequence'] = df.loc[shortened_df.index,'Sequences'].map(short_to_long_dict)
+    return df
+
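+# Quick illustration of the overlapping-window expansion performed by
+# chunkstring_overlap / create_short_seqs_from_long: a sequence longer than
+# max_len is replaced by every window of length max_len with stride 1, and each
+# window is mapped back to its original sequence. Toy values, sketch only.
+def _window_expansion_example():
+    # "ABCDEFG" with a window of 5 -> ['ABCDE', 'BCDEF', 'CDEFG']
+    return list(chunkstring_overlap("ABCDEFG", 5))
+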
+def infer_from_pd(cfg,net,infer_pd,DataClass,attention_flag:bool=False):
+    try:
+        max_len = net.module_.transformer_layers.pos_encoder.pe.shape[1]+1
+    except AttributeError:
+        #baseline models have no positional encoder
+        max_len = 30
+
+    if cfg['model_name'] == 'seq-seq':
+        max_len = max_len*2 - 1
+
+    if len(infer_pd['Sequences'][infer_pd['Sequences'].str.len()>max_len].values)>0:
+        infer_pd = create_short_seqs_from_long(infer_pd,max_len)
+    infer_pd = add_ss_and_labels(infer_pd)
+    if cfg['model_name'] == 'seq-seq':
+        cfg['model_config']['tokens_len'] *= 2
+        cfg['model_config']['second_input_token_len'] *= 2
+
+    #create dataset class to tokenize the inference sequences
+    dataset_class = DataClass(infer_pd,cfg)
+    #update the dataset class with tokenization dicts and tokens_len
+    dataset_class = update_dataclass_inference(cfg,dataset_class)
+    #tokenize sequences
+    all_data = prepare_inference_data(cfg,infer_pd,dataset_class)
+
+    #inference on custom data
+    predicted_labels,logits,attn_scores_first_list,attn_scores_second_list = infer_from_model(net,all_data["infere_data"])
+    if attention_flag:
+        #baseline and seq models only return one set of attention scores
+        if not attn_scores_second_list:
+            attn_scores_second_list = attn_scores_first_list
+
+        attn_scores_first = np.array(attn_scores_first_list)
+        seq_lengths = all_data['infere_rna_seq']['Sequences'].str.len().values
+        #get attention scores for each sequence
+        attn_scores_list = [attn_scores_first[i,:seq_lengths[i],:seq_lengths[i]].flatten().tolist() for i in range(len(seq_lengths))]
+        attn_scores_first_df = pd.DataFrame(data = {'attention_first':attn_scores_list})
+        attn_scores_first_df.index = all_data['infere_rna_seq']['Sequences'].values
+
+        attn_scores_second = np.array(attn_scores_second_list)
+        attn_scores_list = [attn_scores_second[i,:seq_lengths[i],:seq_lengths[i]].flatten().tolist() for i in range(len(seq_lengths))]
+        attn_scores_second_df = pd.DataFrame(data = {'attention_second':attn_scores_list})
+        attn_scores_second_df.index = all_data['infere_rna_seq']['Sequences'].values
+
+        attn_scores_df = attn_scores_first_df.join(attn_scores_second_df)
+        attn_scores_df['Secondary'] = infer_pd["Secondary"].values
+    else:
+        attn_scores_df = None
+
+    gene_embedds_df = None
+    #net.gene_embedds is a list of tensors; stack them into a numpy array
+    if cfg['log_embedds']:
+        gene_embedds = np.vstack(net.gene_embedds)
+        if cfg['model_name'] not in ['baseline']:
+            second_input_embedds = np.vstack(net.second_input_embedds)
+            gene_embedds = np.concatenate((gene_embedds,second_input_embedds),axis=1)
+        gene_embedds_df = pd.DataFrame(data=gene_embedds)
+        gene_embedds_df.index = all_data['infere_rna_seq']['Sequences'].values
+        gene_embedds_df.columns = ['gene_embedds_'+str(i) for i in range(gene_embedds_df.shape[1])]
+
+    return predicted_labels,logits,gene_embedds_df,attn_scores_df,all_data,max_len,net,infer_pd
+
+def log_embedds(cfg,net,seqs_df):
+    gene_embedds = np.vstack(net.gene_embedds)
+    if cfg['model_name'] not in ['seq','baseline']:
+        second_input_embedds = np.vstack(net.second_input_embedds)
+        gene_embedds = np.concatenate((gene_embedds,second_input_embedds),axis=1)
+
+    return seqs_df.join(pd.DataFrame(data=gene_embedds))
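+
+# Hedged end-to-end sketch of how the helpers in this module might be chained
+# for ad-hoc inference. `cfg`, `net` and `DataClass` are assumed to come from
+# the rest of the package (e.g. get_model and the tokenizer dataset class), and
+# the sequences are invented for illustration.
+def _inference_sketch(cfg, net, DataClass):
+    infer_pd = pd.DataFrame({"Sequences": ["ACGTTGCAACGTTGCAACGT", "GGGCCCAAATTTGGGCCCAA"]})
+    (predicted_labels, logits, gene_embedds_df, attn_scores_df,
+     all_data, max_len, net, infer_pd) = infer_from_pd(cfg, net, infer_pd, DataClass)
+    # attach labels and the familiarity flag to the inferred sequences in place
+    prepare_inference_results_tcga(cfg, predicted_labels, logits, all_data, max_len)
+    return all_data["infere_rna_seq"]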