nickmuchi committed
Commit 50dd923
1 Parent(s): 7fcc2a5

Upload 17 files
sentence-transformers/.DS_Store ADDED
Binary file (8.2 kB).
sentence-transformers/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,5 @@
+ # Code of Conduct
+
+ Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
+ Please read the [full text](https://code.fb.com/codeofconduct/)
+ so that you can understand what actions will and will not be tolerated.
sentence-transformers/CONTRIBUTING.md ADDED
@@ -0,0 +1,16 @@
+ # Contributing to this repo
+
+ ## Pull Requests
+
+ In order to accept your pull request, we need you to submit a CLA. You only need
+ to do this once to work on any of Facebook's open source projects.
+
+ Complete your CLA here: <https://code.facebook.com/cla>
+
+ ## Issues
+ We use GitHub issues to track public bugs. Please ensure your description is
+ clear and has sufficient instructions to be able to reproduce the issue.
+
+ ## License
+ By contributing to this repo, you agree that your contributions will be licensed
+ under the LICENSE file in the root directory of this source tree.
sentence-transformers/LICENSE ADDED
@@ -0,0 +1,201 @@
+                                  Apache License
+                            Version 2.0, January 2004
+                         http://www.apache.org/licenses/
+
+    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+    1. Definitions.
+
+       "License" shall mean the terms and conditions for use, reproduction,
+       and distribution as defined by Sections 1 through 9 of this document.
+
+       "Licensor" shall mean the copyright owner or entity authorized by
+       the copyright owner that is granting the License.
+
+       "Legal Entity" shall mean the union of the acting entity and all
+       other entities that control, are controlled by, or are under common
+       control with that entity. For the purposes of this definition,
+       "control" means (i) the power, direct or indirect, to cause the
+       direction or management of such entity, whether by contract or
+       otherwise, or (ii) ownership of fifty percent (50%) or more of the
+       outstanding shares, or (iii) beneficial ownership of such entity.
+
+       "You" (or "Your") shall mean an individual or Legal Entity
+       exercising permissions granted by this License.
+
+       "Source" form shall mean the preferred form for making modifications,
+       including but not limited to software source code, documentation
+       source, and configuration files.
+
+       "Object" form shall mean any form resulting from mechanical
+       transformation or translation of a Source form, including but
+       not limited to compiled object code, generated documentation,
+       and conversions to other media types.
+
+       "Work" shall mean the work of authorship, whether in Source or
+       Object form, made available under the License, as indicated by a
+       copyright notice that is included in or attached to the work
+       (an example is provided in the Appendix below).
+
+       "Derivative Works" shall mean any work, whether in Source or Object
+       form, that is based on (or derived from) the Work and for which the
+       editorial revisions, annotations, elaborations, or other modifications
+       represent, as a whole, an original work of authorship. For the purposes
+       of this License, Derivative Works shall not include works that remain
+       separable from, or merely link (or bind by name) to the interfaces of,
+       the Work and Derivative Works thereof.
+
+       "Contribution" shall mean any work of authorship, including
+       the original version of the Work and any modifications or additions
+       to that Work or Derivative Works thereof, that is intentionally
+       submitted to Licensor for inclusion in the Work by the copyright owner
+       or by an individual or Legal Entity authorized to submit on behalf of
+       the copyright owner. For the purposes of this definition, "submitted"
+       means any form of electronic, verbal, or written communication sent
+       to the Licensor or its representatives, including but not limited to
+       communication on electronic mailing lists, source code control systems,
+       and issue tracking systems that are managed by, or on behalf of, the
+       Licensor for the purpose of discussing and improving the Work, but
+       excluding communication that is conspicuously marked or otherwise
+       designated in writing by the copyright owner as "Not a Contribution."
+
+       "Contributor" shall mean Licensor and any individual or Legal Entity
+       on behalf of whom a Contribution has been received by Licensor and
+       subsequently incorporated within the Work.
+
+    2. Grant of Copyright License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       copyright license to reproduce, prepare Derivative Works of,
+       publicly display, publicly perform, sublicense, and distribute the
+       Work and such Derivative Works in Source or Object form.
+
+    3. Grant of Patent License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       (except as stated in this section) patent license to make, have made,
+       use, offer to sell, sell, import, and otherwise transfer the Work,
+       where such license applies only to those patent claims licensable
+       by such Contributor that are necessarily infringed by their
+       Contribution(s) alone or by combination of their Contribution(s)
+       with the Work to which such Contribution(s) was submitted. If You
+       institute patent litigation against any entity (including a
+       cross-claim or counterclaim in a lawsuit) alleging that the Work
+       or a Contribution incorporated within the Work constitutes direct
+       or contributory patent infringement, then any patent licenses
+       granted to You under this License for that Work shall terminate
+       as of the date such litigation is filed.
+
+    4. Redistribution. You may reproduce and distribute copies of the
+       Work or Derivative Works thereof in any medium, with or without
+       modifications, and in Source or Object form, provided that You
+       meet the following conditions:
+
+       (a) You must give any other recipients of the Work or
+           Derivative Works a copy of this License; and
+
+       (b) You must cause any modified files to carry prominent notices
+           stating that You changed the files; and
+
+       (c) You must retain, in the Source form of any Derivative Works
+           that You distribute, all copyright, patent, trademark, and
+           attribution notices from the Source form of the Work,
+           excluding those notices that do not pertain to any part of
+           the Derivative Works; and
+
+       (d) If the Work includes a "NOTICE" text file as part of its
+           distribution, then any Derivative Works that You distribute must
+           include a readable copy of the attribution notices contained
+           within such NOTICE file, excluding those notices that do not
+           pertain to any part of the Derivative Works, in at least one
+           of the following places: within a NOTICE text file distributed
+           as part of the Derivative Works; within the Source form or
+           documentation, if provided along with the Derivative Works; or,
+           within a display generated by the Derivative Works, if and
+           wherever such third-party notices normally appear. The contents
+           of the NOTICE file are for informational purposes only and
+           do not modify the License. You may add Your own attribution
+           notices within Derivative Works that You distribute, alongside
+           or as an addendum to the NOTICE text from the Work, provided
+           that such additional attribution notices cannot be construed
+           as modifying the License.
+
+       You may add Your own copyright statement to Your modifications and
+       may provide additional or different license terms and conditions
+       for use, reproduction, or distribution of Your modifications, or
+       for any such Derivative Works as a whole, provided Your use,
+       reproduction, and distribution of the Work otherwise complies with
+       the conditions stated in this License.
+
+    5. Submission of Contributions. Unless You explicitly state otherwise,
+       any Contribution intentionally submitted for inclusion in the Work
+       by You to the Licensor shall be under the terms and conditions of
+       this License, without any additional terms or conditions.
+       Notwithstanding the above, nothing herein shall supersede or modify
+       the terms of any separate license agreement you may have executed
+       with Licensor regarding such Contributions.
+
+    6. Trademarks. This License does not grant permission to use the trade
+       names, trademarks, service marks, or product names of the Licensor,
+       except as required for reasonable and customary use in describing the
+       origin of the Work and reproducing the content of the NOTICE file.
+
+    7. Disclaimer of Warranty. Unless required by applicable law or
+       agreed to in writing, Licensor provides the Work (and each
+       Contributor provides its Contributions) on an "AS IS" BASIS,
+       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+       implied, including, without limitation, any warranties or conditions
+       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+       PARTICULAR PURPOSE. You are solely responsible for determining the
+       appropriateness of using or redistributing the Work and assume any
+       risks associated with Your exercise of permissions under this License.
+
+    8. Limitation of Liability. In no event and under no legal theory,
+       whether in tort (including negligence), contract, or otherwise,
+       unless required by applicable law (such as deliberate and grossly
+       negligent acts) or agreed to in writing, shall any Contributor be
+       liable to You for damages, including any direct, indirect, special,
+       incidental, or consequential damages of any character arising as a
+       result of this License or out of the use or inability to use the
+       Work (including but not limited to damages for loss of goodwill,
+       work stoppage, computer failure or malfunction, or any and all
+       other commercial damages or losses), even if such Contributor
+       has been advised of the possibility of such damages.
+
+    9. Accepting Warranty or Additional Liability. While redistributing
+       the Work or Derivative Works thereof, You may choose to offer,
+       and charge a fee for, acceptance of support, warranty, indemnity,
+       or other liability obligations and/or rights consistent with this
+       License. However, in accepting such obligations, You may act only
+       on Your own behalf and on Your sole responsibility, not on behalf
+       of any other Contributor, and only if You agree to indemnify,
+       defend, and hold each Contributor harmless for any liability
+       incurred by, or claims asserted against, such Contributor by reason
+       of your accepting any such warranty or additional liability.
+
+    END OF TERMS AND CONDITIONS
+
+    APPENDIX: How to apply the Apache License to your work.
+
+       To apply the Apache License to your work, attach the following
+       boilerplate notice, with the fields enclosed by brackets "{}"
+       replaced with your own identifying information. (Don't include
+       the brackets!) The text should be enclosed in the appropriate
+       comment syntax for the file format. We also recommend that a
+       file or class name and description of purpose be included on the
+       same "printed page" as the copyright notice for easier
+       identification within third-party archives.
+
+    Copyright 2019 Nils Reimers
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
sentence-transformers/NOTICE.txt ADDED
@@ -0,0 +1,5 @@
+ -------------------------------------------------------------------------------
+ Copyright 2019
+ Ubiquitous Knowledge Processing (UKP) Lab
+ Technische Universität Darmstadt
+ -------------------------------------------------------------------------------
sentence-transformers/README.md ADDED
@@ -0,0 +1,182 @@
+ <!--- BADGES: START --->
+ [![GitHub - License](https://img.shields.io/github/license/UKPLab/sentence-transformers?logo=github&style=flat&color=green)][#github-license]
+ [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/sentence-transformers?logo=pypi&style=flat&color=blue)][#pypi-package]
+ [![PyPI - Package Version](https://img.shields.io/pypi/v/sentence-transformers?logo=pypi&style=flat&color=orange)][#pypi-package]
+ [![Conda - Platform](https://img.shields.io/conda/pn/conda-forge/sentence-transformers?logo=anaconda&style=flat)][#conda-forge-package]
+ [![Conda (channel only)](https://img.shields.io/conda/vn/conda-forge/sentence-transformers?logo=anaconda&style=flat&color=orange)][#conda-forge-package]
+ [![Docs - GitHub.io](https://img.shields.io/static/v1?logo=github&style=flat&color=pink&label=docs&message=sentence-transformers)][#docs-package]
+ <!---
+ [![PyPI - Downloads](https://img.shields.io/pypi/dm/sentence-transformers?logo=pypi&style=flat&color=green)][#pypi-package]
+ [![Conda](https://img.shields.io/conda/dn/conda-forge/sentence-transformers?logo=anaconda)][#conda-forge-package]
+ --->
+
+ [#github-license]: https://github.com/UKPLab/sentence-transformers/blob/master/LICENSE
+ [#pypi-package]: https://pypi.org/project/sentence-transformers/
+ [#conda-forge-package]: https://anaconda.org/conda-forge/sentence-transformers
+ [#docs-package]: https://www.sbert.net/
+ <!--- BADGES: END --->
+
+ # Sentence Transformers: Multilingual Sentence, Paragraph, and Image Embeddings using BERT & Co.
+
+ This framework provides an easy method to compute dense vector representations for **sentences**, **paragraphs**, and **images**. The models are based on transformer networks such as BERT, RoBERTa, and XLM-RoBERTa, and achieve state-of-the-art performance on various tasks. Text is embedded in a vector space such that similar texts are close together and can be found efficiently using cosine similarity.
+
+ We provide an increasing number of **[state-of-the-art pretrained models](https://www.sbert.net/docs/pretrained_models.html)** for more than 100 languages, fine-tuned for various use cases.
+
+ Further, this framework allows easy **[fine-tuning of custom embedding models](https://www.sbert.net/docs/training/overview.html)** to achieve maximal performance on your specific task.
+
+ For the **full documentation**, see **[www.SBERT.net](https://www.sbert.net)**.
+
+ The following publications are integrated in this framework:
+
+ - [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084) (EMNLP 2019)
+ - [Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation](https://arxiv.org/abs/2004.09813) (EMNLP 2020)
+ - [Augmented SBERT: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks](https://arxiv.org/abs/2010.08240) (NAACL 2021)
+ - [The Curse of Dense Low-Dimensional Information Retrieval for Large Index Sizes](https://arxiv.org/abs/2012.14210) (arXiv 2020)
+ - [TSDAE: Using Transformer-based Sequential Denoising Auto-Encoder for Unsupervised Sentence Embedding Learning](https://arxiv.org/abs/2104.06979) (arXiv 2021)
+ - [BEIR: A Heterogenous Benchmark for Zero-shot Evaluation of Information Retrieval Models](https://arxiv.org/abs/2104.08663) (arXiv 2021)
+
+ ## Installation
+
+ We recommend **Python 3.6** or higher, **[PyTorch 1.6.0](https://pytorch.org/get-started/locally/)** or higher, and **[transformers v4.6.0](https://github.com/huggingface/transformers)** or higher. The code does **not** work with Python 2.7.
+
+ **Install with pip**
+
+ Install *sentence-transformers* with `pip`:
+
+ ```
+ pip install -U sentence-transformers
+ ```
+
+ **Install with conda**
+
+ You can install *sentence-transformers* with `conda`:
+
+ ```
+ conda install -c conda-forge sentence-transformers
+ ```
+
+ **Install from sources**
+
+ Alternatively, you can clone the latest version from the [repository](https://github.com/UKPLab/sentence-transformers) and install it directly from the source code:
+
+ ```
+ pip install -e .
+ ```
+
+ **PyTorch with CUDA**
+
+ If you want to use a GPU / CUDA, you must install PyTorch with a matching CUDA version. Follow
+ [PyTorch - Get Started](https://pytorch.org/get-started/locally/) for further details on how to install PyTorch.
+
+ ## Getting Started
+
+ See [Quickstart](https://www.sbert.net/docs/quickstart.html) in our documentation.
+
+ [This example](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications/computing-embeddings/computing_embeddings.py) shows you how to use an already trained Sentence Transformer model to embed sentences for another task.
+
+ First download a pretrained model.
+
+ ```python
+ from sentence_transformers import SentenceTransformer
+ model = SentenceTransformer('all-MiniLM-L6-v2')
+ ```
+
+ Then provide some sentences to the model.
+
+ ```python
+ sentences = ['This framework generates embeddings for each input sentence',
+     'Sentences are passed as a list of string.',
+     'The quick brown fox jumps over the lazy dog.']
+ sentence_embeddings = model.encode(sentences)
+ ```
+
+ And that's it already. We now have a list of numpy arrays with the embeddings.
+
+ ```python
+ for sentence, embedding in zip(sentences, sentence_embeddings):
+     print("Sentence:", sentence)
+     print("Embedding:", embedding)
+     print("")
+ ```
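+
+ Once you have the embeddings, you can compare them with cosine similarity. A minimal sketch (assuming the `util.cos_sim` helper shipped with recent *sentence-transformers* releases):
+
+ ```python
+ from sentence_transformers import util
+
+ # Cosine-similarity matrix between all pairs of sentence embeddings
+ cosine_scores = util.cos_sim(sentence_embeddings, sentence_embeddings)
+ print(cosine_scores)  # entries close to 1.0 indicate similar sentences
+ ```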
+
+ ## Pre-Trained Models
+
+ We provide a large list of [Pretrained Models](https://www.sbert.net/docs/pretrained_models.html) for more than 100 languages. Some models are general-purpose, while others produce embeddings for specific use cases. Pre-trained models can be loaded by just passing the model name: `SentenceTransformer('model_name')`.
+
+ [» Full list of pretrained models](https://www.sbert.net/docs/pretrained_models.html)
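+
+ For example, loading a specific model and inspecting its output dimensionality might look like this (a sketch; `paraphrase-multilingual-MiniLM-L12-v2` is one of the published multilingual models, and `get_sentence_embedding_dimension()` is the accessor in recent versions):
+
+ ```python
+ from sentence_transformers import SentenceTransformer
+
+ # Load a multilingual general-purpose model by name
+ model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
+ print(model.get_sentence_embedding_dimension())  # 384 for this model
+ ```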
+
+ ## Training
+
+ This framework allows you to fine-tune your own sentence embedding methods, so that you get task-specific sentence embeddings. You have various options to choose from in order to get sentence embeddings well suited to your specific task.
+
+ See [Training Overview](https://www.sbert.net/docs/training/overview.html) for an introduction to training your own embedding models. We provide [various examples](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training) of how to train models on various datasets.
+
+ Some highlights are:
+ - Support of various transformer networks including BERT, RoBERTa, XLM-R, DistilBERT, Electra, BART, ...
+ - Multilingual and multi-task learning
+ - Evaluation during training to find the optimal model
+ - [10+ loss functions](https://www.sbert.net/docs/package_reference/losses.html) that let you tune models specifically for semantic search, paraphrase mining, semantic similarity comparison, clustering, triplet loss, and contrastive loss (a minimal training sketch follows this list)
+
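+ A minimal fine-tuning sketch (assuming the `InputExample`, `losses`, and `model.fit` training API of recent *sentence-transformers* releases; the two training pairs are made up for illustration):
+
+ ```python
+ from torch.utils.data import DataLoader
+ from sentence_transformers import SentenceTransformer, InputExample, losses
+
+ model = SentenceTransformer('all-MiniLM-L6-v2')
+
+ # Toy training pairs with similarity labels in [0, 1]
+ train_examples = [
+     InputExample(texts=['A man is eating food.', 'A man is eating a meal.'], label=0.9),
+     InputExample(texts=['A man is eating food.', 'A plane is taking off.'], label=0.1),
+ ]
+ train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
+ train_loss = losses.CosineSimilarityLoss(model)
+
+ # One pass over the toy data; real training uses many more pairs and steps
+ model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)
+ ```
+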
+ ## Performance
+
+ Our models are evaluated extensively on 15+ datasets, including challenging domains like tweets, Reddit posts, and emails. They achieve by far the **best performance** of all available sentence embedding methods. Further, we provide several **smaller models** that are **optimized for speed**.
+
+ [» Full list of pretrained models](https://www.sbert.net/docs/pretrained_models.html)
+
+ ## Application Examples
+
+ You can use this framework for:
+
+ - [Computing Sentence Embeddings](https://www.sbert.net/examples/applications/computing-embeddings/README.html)
+ - [Semantic Textual Similarity](https://www.sbert.net/docs/usage/semantic_textual_similarity.html)
+ - [Clustering](https://www.sbert.net/examples/applications/clustering/README.html)
+ - [Paraphrase Mining](https://www.sbert.net/examples/applications/paraphrase-mining/README.html)
+ - [Translated Sentence Mining](https://www.sbert.net/examples/applications/parallel-sentence-mining/README.html)
+ - [Semantic Search](https://www.sbert.net/examples/applications/semantic-search/README.html)
+ - [Retrieve & Re-Rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html)
+ - [Text Summarization](https://www.sbert.net/examples/applications/text-summarization/README.html)
+ - [Multilingual Image Search, Clustering & Duplicate Detection](https://www.sbert.net/examples/applications/image-search/README.html)
+
+ and many more use cases.
+
+ For all examples, see [examples/applications](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications).
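+
+ As a taste of one of these applications, here is a paraphrase-mining sketch (assuming the `util.paraphrase_mining` helper available in recent releases):
+
+ ```python
+ from sentence_transformers import SentenceTransformer, util
+
+ model = SentenceTransformer('all-MiniLM-L6-v2')
+ sentences = ['The cat sits outside',
+     'A man is playing guitar',
+     'The new movie is awesome',
+     'The cat plays in the garden',
+     'The new movie is so great']
+
+ # Returns (score, i, j) triplets for the most similar sentence pairs
+ paraphrases = util.paraphrase_mining(model, sentences)
+ for score, i, j in paraphrases[:3]:
+     print(f"{sentences[i]} <> {sentences[j]}: {score:.4f}")
+ ```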
+
+ ## Citing & Authors
+
+ If you find this repository helpful, feel free to cite our publication [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084):
+
+ ```bibtex
+ @inproceedings{reimers-2019-sentence-bert,
+     title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
+     author = "Reimers, Nils and Gurevych, Iryna",
+     booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
+     month = "11",
+     year = "2019",
+     publisher = "Association for Computational Linguistics",
+     url = "https://arxiv.org/abs/1908.10084",
+ }
+ ```
+
+ If you use one of the multilingual models, feel free to cite our publication [Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation](https://arxiv.org/abs/2004.09813):
+
+ ```bibtex
+ @inproceedings{reimers-2020-multilingual-sentence-bert,
+     title = "Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation",
+     author = "Reimers, Nils and Gurevych, Iryna",
+     booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing",
+     month = "11",
+     year = "2020",
+     publisher = "Association for Computational Linguistics",
+     url = "https://arxiv.org/abs/2004.09813",
+ }
+ ```
+
+ Please have a look at [Publications](https://www.sbert.net/docs/publications.html) for our different publications that are integrated into SentenceTransformers.
+
+ Contact person: [Nils Reimers](https://www.nils-reimers.de), [info@nils-reimers.de](mailto:info@nils-reimers.de)
+
+ https://www.ukp.tu-darmstadt.de/
+
+ Don't hesitate to send us an e-mail or report an issue, if something is broken (and it shouldn't be) or if you have further questions.
+
+ > This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.
sentence-transformers/eval_beir.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import sys
+ import argparse
+ import torch
+ import logging
+ import json
+ import numpy as np
+ import os
+
+ import src.slurm
+ import src.contriever
+ import src.beir_utils
+ import src.utils
+ import src.dist_utils
+
+ logger = logging.getLogger(__name__)
+
+
+ def main(args):
+
+     src.slurm.init_distributed_mode(args)
+     src.slurm.init_signal_handler()
+
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     logger = src.utils.init_logger(args)
+
+     model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path)
+     model = model.cuda()
+     model.eval()
+     query_encoder = model
+     doc_encoder = model
+
+     logger.info("Start indexing")
+
+     metrics = src.beir_utils.evaluate_model(
+         query_encoder=query_encoder,
+         doc_encoder=doc_encoder,
+         tokenizer=tokenizer,
+         dataset=args.dataset,
+         batch_size=args.per_gpu_batch_size,
+         norm_query=args.norm_query,
+         norm_doc=args.norm_doc,
+         is_main=src.dist_utils.is_main(),
+         split="dev" if args.dataset == "msmarco" else "test",
+         score_function=args.score_function,
+         beir_dir=args.beir_dir,
+         save_results_path=args.save_results_path,
+         lower_case=args.lower_case,
+         normalize_text=args.normalize_text,
+     )
+
+     if src.dist_utils.is_main():
+         for key, value in metrics.items():
+             logger.info(f"{args.dataset} : {key}: {value:.1f}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+     parser.add_argument("--dataset", type=str, help="Evaluation dataset from the BEIR benchmark")
+     parser.add_argument("--beir_dir", type=str, default="./", help="Directory to save and load beir datasets")
+     parser.add_argument("--text_maxlength", type=int, default=512, help="Maximum text length")
+
+     parser.add_argument("--per_gpu_batch_size", default=128, type=int, help="Batch size per GPU/CPU for indexing.")
+     parser.add_argument("--output_dir", type=str, default="./my_experiment", help="Output directory")
+     parser.add_argument("--model_name_or_path", type=str, help="Model name or path")
+     parser.add_argument(
+         "--score_function", type=str, default="dot", help="Metric used to compute similarity between two embeddings"
+     )
+     parser.add_argument("--norm_query", action="store_true", help="Normalize query representation")
+     parser.add_argument("--norm_doc", action="store_true", help="Normalize document representation")
+     parser.add_argument("--lower_case", action="store_true", help="Lowercase query and document text")
+     parser.add_argument(
+         "--normalize_text", action="store_true", help="Apply function to normalize some common characters"
+     )
+     parser.add_argument("--save_results_path", type=str, default=None, help="Path to save result object")
+
+     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+     parser.add_argument("--main_port", type=int, default=-1, help="Main port (for multi-node SLURM jobs)")
+
+     args, _ = parser.parse_known_args()
+     main(args)
sentence-transformers/evaluate_retrieved_passages.py ADDED
@@ -0,0 +1,66 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import argparse
+ import json
+ import logging
+ import glob
+
+ import numpy as np
+ import torch
+
+ import src.utils
+
+ from src.evaluation import calculate_matches
+
+ logger = logging.getLogger(__name__)
+
+
+ def validate(data, workers_num):
+     match_stats = calculate_matches(data, workers_num)
+     top_k_hits = match_stats.top_k_hits
+
+     #logger.info('Validation results: top k documents hits %s', top_k_hits)
+     top_k_hits = [v / len(data) for v in top_k_hits]
+     #logger.info('Validation results: top k documents hits accuracy %s', top_k_hits)
+     return top_k_hits
+
+
+ def main(args):
+     logger = src.utils.init_logger(args, stdout_only=True)
+     datapaths = glob.glob(args.data)
+     r20, r100 = [], []
+     for path in datapaths:
+         data = []
+         with open(path, 'r') as fin:
+             for line in fin:
+                 data.append(json.loads(line))
+             #data = json.load(fin)
+         answers = [ex['answers'] for ex in data]
+         top_k_hits = validate(data, args.validation_workers)
+         message = f"Evaluate results from {path}:"
+         for k in [5, 10, 20, 100]:
+             if k <= len(top_k_hits):
+                 recall = 100 * top_k_hits[k - 1]
+                 if k == 20:
+                     r20.append(f"{recall:.1f}")
+                 if k == 100:
+                     r100.append(f"{recall:.1f}")
+                 message += f' R@{k}: {recall:.1f}'
+         logger.info(message)
+     print(datapaths)
+     print('\t'.join(r20))
+     print('\t'.join(r100))
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument('--data', required=True, type=str, default=None)
+     parser.add_argument('--validation_workers', type=int, default=16,
+                         help="Number of parallel processes to validate results")
+
+     args = parser.parse_args()
+     main(args)
sentence-transformers/finetuning.py ADDED
@@ -0,0 +1,249 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+ import pdb
+ import os
+ import time
+ import sys
+ import torch
+ from torch.utils.tensorboard import SummaryWriter
+ import logging
+ import json
+ import numpy as np
+ import torch.distributed as dist
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+
+ from src.options import Options
+ from src import data, beir_utils, slurm, dist_utils, utils, contriever, finetuning_data, inbatch
+
+ import train
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ logger = logging.getLogger(__name__)
+
+
+ def finetuning(opt, model, optimizer, scheduler, tokenizer, step):
+
+     run_stats = utils.WeightedAvgStats()
+
+     tb_logger = utils.init_tb_logger(opt.output_dir)
+
+     if hasattr(model, "module"):
+         eval_model = model.module
+     else:
+         eval_model = model
+     eval_model = eval_model.get_encoder()
+
+     train_dataset = finetuning_data.Dataset(
+         datapaths=opt.train_data,
+         negative_ctxs=opt.negative_ctxs,
+         negative_hard_ratio=opt.negative_hard_ratio,
+         negative_hard_min_idx=opt.negative_hard_min_idx,
+         normalize=opt.eval_normalize_text,
+         global_rank=dist_utils.get_rank(),
+         world_size=dist_utils.get_world_size(),
+         maxload=opt.maxload,
+         training=True,
+     )
+     collator = finetuning_data.Collator(tokenizer, passage_maxlength=opt.chunk_length)
+     train_sampler = RandomSampler(train_dataset)
+     train_dataloader = DataLoader(
+         train_dataset,
+         sampler=train_sampler,
+         batch_size=opt.per_gpu_batch_size,
+         drop_last=True,
+         num_workers=opt.num_workers,
+         collate_fn=collator,
+     )
+
+     train.eval_model(opt, eval_model, None, tokenizer, tb_logger, step)
+     evaluate(opt, eval_model, tokenizer, tb_logger, step)
+
+     epoch = 1
+
+     model.train()
+     prev_ids, prev_mask = None, None
+     while step < opt.total_steps:
+         logger.info(f"Start epoch {epoch}, number of batches: {len(train_dataloader)}")
+         for i, batch in enumerate(train_dataloader):
+             batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
+             step += 1
+
+             train_loss, iter_stats = model(**batch, stats_prefix="train")
+             train_loss.backward()
+
+             if opt.optim == "sam" or opt.optim == "asam":
+                 optimizer.first_step(zero_grad=True)
+
+                 sam_loss, _ = model(**batch, stats_prefix="train/sam_opt")
+                 sam_loss.backward()
+                 optimizer.second_step(zero_grad=True)
+             else:
+                 optimizer.step()
+             scheduler.step()
+             optimizer.zero_grad()
+
+             run_stats.update(iter_stats)
+
+             if step % opt.log_freq == 0:
+                 log = f"{step} / {opt.total_steps}"
+                 for k, v in sorted(run_stats.average_stats.items()):
+                     log += f" | {k}: {v:.3f}"
+                     if tb_logger:
+                         tb_logger.add_scalar(k, v, step)
+                 log += f" | lr: {scheduler.get_last_lr()[0]:0.3g}"
+                 log += f" | Memory: {torch.cuda.max_memory_allocated()//1e9} GiB"
+
+                 logger.info(log)
+                 run_stats.reset()
+
+             if step % opt.eval_freq == 0:
+
+                 train.eval_model(opt, eval_model, None, tokenizer, tb_logger, step)
+                 evaluate(opt, eval_model, tokenizer, tb_logger, step)
+
+                 if step % opt.save_freq == 0 and dist_utils.get_rank() == 0:
+                     utils.save(
+                         eval_model,
+                         optimizer,
+                         scheduler,
+                         step,
+                         opt,
+                         opt.output_dir,
+                         f"step-{step}",
+                     )
+                 model.train()
+
+             if step >= opt.total_steps:
+                 break
+
+         epoch += 1
+
+
+ def evaluate(opt, model, tokenizer, tb_logger, step):
+     dataset = finetuning_data.Dataset(
+         datapaths=opt.eval_data,
+         normalize=opt.eval_normalize_text,
+         global_rank=dist_utils.get_rank(),
+         world_size=dist_utils.get_world_size(),
+         maxload=opt.maxload,
+         training=False,
+     )
+     collator = finetuning_data.Collator(tokenizer, passage_maxlength=opt.chunk_length)
+     sampler = SequentialSampler(dataset)
+     dataloader = DataLoader(
+         dataset,
+         sampler=sampler,
+         batch_size=opt.per_gpu_batch_size,
+         drop_last=False,
+         num_workers=opt.num_workers,
+         collate_fn=collator,
+     )
+
+     model.eval()
+     if hasattr(model, "module"):
+         model = model.module
+     correct_samples, total_samples, total_step = 0, 0, 0
+     all_q, all_g, all_n = [], [], []
+     with torch.no_grad():
+         for i, batch in enumerate(dataloader):
+             batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
+
+             all_tokens = torch.cat([batch["g_tokens"], batch["n_tokens"]], dim=0)
+             all_mask = torch.cat([batch["g_mask"], batch["n_mask"]], dim=0)
+
+             q_emb = model(input_ids=batch["q_tokens"], attention_mask=batch["q_mask"], normalize=opt.norm_query)
+             all_emb = model(input_ids=all_tokens, attention_mask=all_mask, normalize=opt.norm_doc)
+
+             g_emb, n_emb = torch.split(all_emb, [len(batch["g_tokens"]), len(batch["n_tokens"])])
+
+             all_q.append(q_emb)
+             all_g.append(g_emb)
+             all_n.append(n_emb)
+
+         all_q = torch.cat(all_q, dim=0)
+         all_g = torch.cat(all_g, dim=0)
+         all_n = torch.cat(all_n, dim=0)
+
+         labels = torch.arange(0, len(all_q), device=all_q.device, dtype=torch.long)
+
+         all_sizes = dist_utils.get_varsize(all_g)
+         all_g = dist_utils.varsize_gather_nograd(all_g)
+         all_n = dist_utils.varsize_gather_nograd(all_n)
+         labels = labels + sum(all_sizes[: dist_utils.get_rank()])
+
+         scores_pos = torch.einsum("id, jd->ij", all_q, all_g)
+         scores_neg = torch.einsum("id, jd->ij", all_q, all_n)
+         scores = torch.cat([scores_pos, scores_neg], dim=-1)
+
+         argmax_idx = torch.argmax(scores, dim=1)
+         sorted_scores, indices = torch.sort(scores, descending=True)
+         isrelevant = indices == labels[:, None]
+         rs = [r.cpu().numpy().nonzero()[0] for r in isrelevant]
+         mrr = np.mean([1.0 / (r[0] + 1) if r.size else 0.0 for r in rs])
+
+         acc = (argmax_idx == labels).sum() / all_q.size(0)
+         acc, total = dist_utils.weighted_average(acc, all_q.size(0))
+         mrr, _ = dist_utils.weighted_average(mrr, all_q.size(0))
+         acc = 100 * acc
+
+         message = []
+         if dist_utils.is_main():
+             message = [f"eval acc: {acc:.2f}%", f"eval mrr: {mrr:.3f}"]
+             logger.info(" | ".join(message))
+             if tb_logger is not None:
+                 tb_logger.add_scalar("eval_acc", acc, step)
+                 tb_logger.add_scalar("mrr", mrr, step)
+
+
+ def main():
+     logger.info("Start")
+
+     options = Options()
+     opt = options.parse()
+
+     torch.manual_seed(opt.seed)
+     slurm.init_distributed_mode(opt)
+     slurm.init_signal_handler()
+
+     directory_exists = os.path.isdir(opt.output_dir)
+     if dist.is_initialized():
+         dist.barrier()
+     os.makedirs(opt.output_dir, exist_ok=True)
+     if not directory_exists and dist_utils.is_main():
+         options.print_options(opt)
+     if dist.is_initialized():
+         dist.barrier()
+     utils.init_logger(opt)
+
+     step = 0
+
+     retriever, tokenizer, retriever_model_id = contriever.load_retriever(opt.model_path, opt.pooling, opt.random_init)
+     opt.retriever_model_id = retriever_model_id
+     model = inbatch.InBatch(opt, retriever, tokenizer)
+
+     model = model.cuda()
+
+     optimizer, scheduler = utils.set_optim(opt, model)
+     # if dist_utils.is_main():
+     #     utils.save(model, optimizer, scheduler, global_step, 0., opt, opt.output_dir, f"step-{0}")
+     logger.info(utils.get_parameters(model))
+
+     for name, module in model.named_modules():
+         if isinstance(module, torch.nn.Dropout):
+             module.p = opt.dropout
+
+     if torch.distributed.is_initialized():
+         model = torch.nn.parallel.DistributedDataParallel(
+             model,
+             device_ids=[opt.local_rank],
+             output_device=opt.local_rank,
+             find_unused_parameters=False,
+         )
+
+     logger.info("Start training")
+     finetuning(opt, model, optimizer, scheduler, tokenizer, step)
+
+
+ if __name__ == "__main__":
+     main()
sentence-transformers/generate_passage_embeddings.py ADDED
@@ -0,0 +1,124 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+
+ import argparse
+ import csv
+ import logging
+ import pickle
+
+ import numpy as np
+ import torch
+
+ import transformers
+
+ import src.slurm
+ import src.contriever
+ import src.utils
+ import src.data
+ import src.normalize_text
+
+
+ def embed_passages(args, passages, model, tokenizer):
+     total = 0
+     allids, allembeddings = [], []
+     batch_ids, batch_text = [], []
+     with torch.no_grad():
+         for k, p in enumerate(passages):
+             batch_ids.append(p["id"])
+             if args.no_title or "title" not in p:
+                 text = p["text"]
+             else:
+                 text = p["title"] + " " + p["text"]
+             if args.lowercase:
+                 text = text.lower()
+             if args.normalize_text:
+                 text = src.normalize_text.normalize(text)
+             batch_text.append(text)
+
+             if len(batch_text) == args.per_gpu_batch_size or k == len(passages) - 1:
+
+                 encoded_batch = tokenizer.batch_encode_plus(
+                     batch_text,
+                     return_tensors="pt",
+                     max_length=args.passage_maxlength,
+                     padding=True,
+                     truncation=True,
+                 )
+
+                 encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
+                 embeddings = model(**encoded_batch)
+
+                 embeddings = embeddings.cpu()
+                 total += len(batch_ids)
+                 allids.extend(batch_ids)
+                 allembeddings.append(embeddings)
+
+                 batch_text = []
+                 batch_ids = []
+                 if k % 100000 == 0 and k > 0:
+                     print(f"Encoded passages {total}")
+
+     allembeddings = torch.cat(allembeddings, dim=0).numpy()
+     return allids, allembeddings
+
+
+ def main(args):
+     model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path)
+     print(f"Model loaded from {args.model_name_or_path}.", flush=True)
+     model.eval()
+     model = model.cuda()
+     if not args.no_fp16:
+         model = model.half()
+
+     passages = src.data.load_passages(args.passages)
+
+     shard_size = len(passages) // args.num_shards
+     start_idx = args.shard_id * shard_size
+     end_idx = start_idx + shard_size
+     if args.shard_id == args.num_shards - 1:
+         end_idx = len(passages)
+
+     passages = passages[start_idx:end_idx]
+     print(f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.")
+
+     allids, allembeddings = embed_passages(args, passages, model, tokenizer)
+
+     save_file = os.path.join(args.output_dir, args.prefix + f"_{args.shard_id:02d}")
+     os.makedirs(args.output_dir, exist_ok=True)
+     print(f"Saving {len(allids)} passage embeddings to {save_file}.")
+     with open(save_file, mode="wb") as f:
+         pickle.dump((allids, allembeddings), f)
+
+     print(f"Total passages processed {len(allids)}. Written to {save_file}.")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
+     parser.add_argument("--output_dir", type=str, default="wikipedia_embeddings", help="dir path to save embeddings")
+     parser.add_argument("--prefix", type=str, default="passages", help="prefix path to save embeddings")
+     parser.add_argument("--shard_id", type=int, default=0, help="Id of the current shard")
+     parser.add_argument("--num_shards", type=int, default=1, help="Total number of shards")
+     parser.add_argument(
+         "--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass"
+     )
+     parser.add_argument("--passage_maxlength", type=int, default=512, help="Maximum number of tokens in a passage")
+     parser.add_argument(
+         "--model_name_or_path", type=str, help="path to directory containing model weights and config file"
+     )
+     parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
+     parser.add_argument("--no_title", action="store_true", help="title not added to the passage body")
+     parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
+     parser.add_argument("--normalize_text", action="store_true", help="normalize text before encoding")
+
+     args = parser.parse_args()
+
+     src.slurm.init_distributed_mode(args)
+
+     main(args)
sentence-transformers/index.rst ADDED
@@ -0,0 +1,189 @@
+ SentenceTransformers Documentation
+ =================================================
+
+ SentenceTransformers is a Python framework for state-of-the-art sentence, text and image embeddings. The initial work is described in our paper `Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks <https://arxiv.org/abs/1908.10084>`_.
+
+ You can use this framework to compute sentence / text embeddings for more than 100 languages. These embeddings can then be compared e.g. with cosine similarity to find sentences with a similar meaning. This can be useful for `semantic textual similarity <docs/usage/semantic_textual_similarity.html>`_, `semantic search <examples/applications/semantic-search/README.html>`_, or `paraphrase mining <examples/applications/paraphrase-mining/README.html>`_.
+
+ The framework is based on `PyTorch <https://pytorch.org/>`_ and `Transformers <https://huggingface.co/transformers/>`_ and offers a large collection of `pre-trained models <docs/pretrained_models.html>`_ tuned for various tasks. Further, it is easy to `fine-tune your own models <docs/training/overview.html>`_.
+
+
+ Installation
+ =================================================
+
+ You can install it using pip:
+
+ .. code-block:: bash
+
+    pip install -U sentence-transformers
+
+
+ We recommend **Python 3.6** or higher, and at least **PyTorch 1.6.0**. See `installation <docs/installation.html>`_ for further installation options, especially if you want to use a GPU.
+
+
+
+ Usage
+ =================================================
+ The usage is as simple as:
+
+ .. code-block:: python
+
+    from sentence_transformers import SentenceTransformer
+    model = SentenceTransformer('all-MiniLM-L6-v2')
+
+    # Our sentences we like to encode
+    sentences = ['This framework generates embeddings for each input sentence',
+        'Sentences are passed as a list of string.',
+        'The quick brown fox jumps over the lazy dog.']
+
+    # Sentences are encoded by calling model.encode()
+    embeddings = model.encode(sentences)
+
+    # Print the embeddings
+    for sentence, embedding in zip(sentences, embeddings):
+        print("Sentence:", sentence)
+        print("Embedding:", embedding)
+        print("")
+
+
+ Performance
+ =========================
+
+ Our models are evaluated extensively and achieve state-of-the-art performance on various tasks. Further, the code is tuned to provide the highest possible speed. Have a look at `Pre-Trained Models <https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models/>`_ for an overview of available models and the respective performance on different tasks.
+
+
+ Contact
+ =========================
+
+ Contact person: Nils Reimers, info@nils-reimers.de
+
+ https://www.ukp.tu-darmstadt.de/
+
+ Don't hesitate to send us an e-mail or report an issue, if something is broken (and it shouldn't be) or if you have further questions.
+
+ *This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.*
+
+
+ Citing & Authors
+ =========================
+
+ If you find this repository helpful, feel free to cite our publication `Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks <https://arxiv.org/abs/1908.10084>`_:
+
+ .. code-block:: bibtex
+
+    @inproceedings{reimers-2019-sentence-bert,
+        title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
+        author = "Reimers, Nils and Gurevych, Iryna",
+        booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
+        month = "11",
+        year = "2019",
+        publisher = "Association for Computational Linguistics",
+        url = "https://arxiv.org/abs/1908.10084",
+    }
+
+ If you use one of the multilingual models, feel free to cite our publication `Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation <https://arxiv.org/abs/2004.09813>`_:
+
+ .. code-block:: bibtex
+
+    @inproceedings{reimers-2020-multilingual-sentence-bert,
+        title = "Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation",
+        author = "Reimers, Nils and Gurevych, Iryna",
+        booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing",
+        month = "11",
+        year = "2020",
+        publisher = "Association for Computational Linguistics",
+        url = "https://arxiv.org/abs/2004.09813",
+    }
+
+ If you use the code for `data augmentation <https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/data_augmentation>`_, feel free to cite our publication `Augmented SBERT: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks <https://arxiv.org/abs/2010.08240>`_:
+
+ .. code-block:: bibtex
+
+    @inproceedings{thakur-2020-AugSBERT,
+        title = "Augmented {SBERT}: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks",
+        author = "Thakur, Nandan and Reimers, Nils and Daxenberger, Johannes and Gurevych, Iryna",
+        booktitle = "Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies",
+        month = jun,
+        year = "2021",
+        address = "Online",
+        publisher = "Association for Computational Linguistics",
+        url = "https://www.aclweb.org/anthology/2021.naacl-main.28",
+        pages = "296--310",
+    }
+
+
+ .. toctree::
+    :maxdepth: 2
+    :caption: Overview
+
+    docs/installation
+    docs/quickstart
+    docs/pretrained_models
+    docs/pretrained_cross-encoders
+    docs/publications
+    docs/hugging_face
+
+ .. toctree::
+    :maxdepth: 2
+    :caption: Usage
+
+    examples/applications/computing-embeddings/README
+    docs/usage/semantic_textual_similarity
+    examples/applications/semantic-search/README
+    examples/applications/retrieve_rerank/README
+    examples/applications/clustering/README
+    examples/applications/paraphrase-mining/README
+    examples/applications/parallel-sentence-mining/README
+    examples/applications/cross-encoder/README
+    examples/applications/image-search/README
+
+ .. toctree::
+    :maxdepth: 2
+    :caption: Training
+
+    docs/training/overview
+    examples/training/multilingual/README
+    examples/training/distillation/README
+    examples/training/cross-encoder/README
+    examples/training/data_augmentation/README
+
+ .. toctree::
+    :maxdepth: 2
+    :caption: Training Examples
+
+    examples/training/sts/README
+    examples/training/nli/README
+    examples/training/paraphrases/README
+    examples/training/quora_duplicate_questions/README
+    examples/training/ms_marco/README
+
+ .. toctree::
+    :maxdepth: 2
+    :caption: Unsupervised Learning
+
+    examples/unsupervised_learning/README
+    examples/domain_adaptation/README
+
+ .. toctree::
+    :maxdepth: 1
+    :caption: Package Reference
+
+    docs/package_reference/SentenceTransformer
+    docs/package_reference/util
+    docs/package_reference/models
+    docs/package_reference/losses
+    docs/package_reference/evaluation
+    docs/package_reference/datasets
+    docs/package_reference/cross_encoder
sentence-transformers/passage_retrieval.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import argparse
9
+ import csv
10
+ import json
11
+ import logging
12
+ import pickle
13
+ import time
14
+ import glob
15
+ from pathlib import Path
16
+
17
+ import numpy as np
18
+ import torch
19
+ import transformers
20
+
21
+ import src.index
22
+ import src.contriever
23
+ import src.utils
24
+ import src.slurm
25
+ import src.data
26
+ from src.evaluation import calculate_matches
27
+ import src.normalize_text
28
+
29
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
30
+
31
+
32
+ def embed_queries(args, queries, model, tokenizer):
33
+ model.eval()
34
+ embeddings, batch_question = [], []
35
+ with torch.no_grad():
36
+
37
+ for k, q in enumerate(queries):
38
+ if args.lowercase:
39
+ q = q.lower()
40
+ if args.normalize_text:
41
+ q = src.normalize_text.normalize(q)
42
+ batch_question.append(q)
43
+
44
+ if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1:
45
+
46
+ encoded_batch = tokenizer.batch_encode_plus(
47
+ batch_question,
48
+ return_tensors="pt",
49
+ max_length=args.question_maxlength,
50
+ padding=True,
51
+ truncation=True,
52
+ )
53
+ encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
54
+ output = model(**encoded_batch)
55
+ embeddings.append(output.cpu())
56
+
57
+ batch_question = []
58
+
59
+ embeddings = torch.cat(embeddings, dim=0)
60
+ print(f"Questions embeddings shape: {embeddings.size()}")
61
+
62
+ return embeddings.numpy()
63
+
64
+
65
+ def index_encoded_data(index, embedding_files, indexing_batch_size):
66
+ allids = []
67
+ allembeddings = np.array([])
68
+ for i, file_path in enumerate(embedding_files):
69
+ print(f"Loading file {file_path}")
70
+ with open(file_path, "rb") as fin:
71
+ ids, embeddings = pickle.load(fin)
72
+
73
+ allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings
74
+ allids.extend(ids)
75
+ while allembeddings.shape[0] > indexing_batch_size:
76
+ allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
77
+
78
+ while allembeddings.shape[0] > 0:
79
+ allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)
80
+
81
+ print("Data indexing completed.")
82
+
83
+
84
+ def add_embeddings(index, embeddings, ids, indexing_batch_size):
85
+ end_idx = min(indexing_batch_size, embeddings.shape[0])
86
+ ids_toadd = ids[:end_idx]
87
+ embeddings_toadd = embeddings[:end_idx]
88
+ ids = ids[end_idx:]
89
+ embeddings = embeddings[end_idx:]
90
+ index.index_data(ids_toadd, embeddings_toadd)
91
+ return embeddings, ids
92
+
93
+
94
+ def validate(data, workers_num):
95
+ match_stats = calculate_matches(data, workers_num)
96
+ top_k_hits = match_stats.top_k_hits
97
+
98
+ print("Validation results: top k documents hits %s", top_k_hits)
99
+ top_k_hits = [v / len(data) for v in top_k_hits]
100
+ message = ""
101
+ for k in [5, 10, 20, 100]:
102
+ if k <= len(top_k_hits):
103
+ message += f"R@{k}: {top_k_hits[k-1]} "
104
+ print(message)
105
+ return match_stats.questions_doc_hits
106
+
107
+
108
+ def add_passages(data, passages, top_passages_and_scores):
+     # add retrieved passages to the original data
+     assert len(data) == len(top_passages_and_scores)
+     for i, d in enumerate(data):
+         results_and_scores = top_passages_and_scores[i]
+         docs = [passages[doc_id] for doc_id in results_and_scores[0]]
+         scores = [str(score) for score in results_and_scores[1]]
+         ctxs_num = len(docs)
+         d["ctxs"] = [
+             {
+                 "id": results_and_scores[0][c],
+                 "title": docs[c]["title"],
+                 "text": docs[c]["text"],
+                 "score": scores[c],
+             }
+             for c in range(ctxs_num)
+         ]
+
+
+ def add_hasanswer(data, hasanswer):
+     # add hasanswer flags to data
+     for i, ex in enumerate(data):
+         for k, d in enumerate(ex["ctxs"]):
+             d["hasanswer"] = hasanswer[i][k]
+
+
+ def load_data(data_path):
+     if data_path.endswith(".json"):
+         with open(data_path, "r") as fin:
+             data = json.load(fin)
+     elif data_path.endswith(".jsonl"):
+         data = []
+         with open(data_path, "r") as fin:
+             for k, example in enumerate(fin):
+                 example = json.loads(example)
+                 data.append(example)
+     return data
+
+
+ def main(args):
+     print(f"Loading model from: {args.model_name_or_path}")
+     model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path)
+     model.eval()
+     model = model.cuda()
+     if not args.no_fp16:
+         model = model.half()
+
+     index = src.index.Indexer(args.projection_size, args.n_subquantizers, args.n_bits)
+
+     # index all passages
+     input_paths = glob.glob(args.passages_embeddings)
+     input_paths = sorted(input_paths)
+     embeddings_dir = os.path.dirname(input_paths[0])
+     index_path = os.path.join(embeddings_dir, "index.faiss")
+     if args.save_or_load_index and os.path.exists(index_path):
+         index.deserialize_from(embeddings_dir)
+     else:
+         print(f"Indexing passages from files {input_paths}")
+         start_time_indexing = time.time()
+         index_encoded_data(index, input_paths, args.indexing_batch_size)
+         print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
+         if args.save_or_load_index:
+             index.serialize(embeddings_dir)
+
+     # load passages
+     passages = src.data.load_passages(args.passages)
+     passage_id_map = {x["id"]: x for x in passages}
+
+     data_paths = glob.glob(args.data)
+     for path in data_paths:
+         data = load_data(path)
+         output_path = os.path.join(args.output_dir, os.path.basename(path))
+
+         queries = [ex["question"] for ex in data]
+         questions_embedding = embed_queries(args, queries, model, tokenizer)
+
+         # get top k results
+         start_time_retrieval = time.time()
+         top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)
+         print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")
+
+         add_passages(data, passage_id_map, top_ids_and_scores)
+         hasanswer = validate(data, args.validation_workers)
+         add_hasanswer(data, hasanswer)
+         os.makedirs(os.path.dirname(output_path), exist_ok=True)
+         with open(output_path, "w") as fout:
+             for ex in data:
+                 json.dump(ex, fout, ensure_ascii=False)
+                 fout.write("\n")
+         print(f"Saved results to {output_path}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "--data",
+         required=True,
+         type=str,
+         help=".json file containing questions and answers, similar format to reader data",
+     )
+     parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
+     parser.add_argument("--passages_embeddings", type=str, default=None, help="Glob path to encoded passages")
+     parser.add_argument(
+         "--output_dir", type=str, default=None, help="Results are written to output_dir, named after the input data file"
+     )
+     parser.add_argument("--n_docs", type=int, default=100, help="Number of documents to retrieve per question")
+     parser.add_argument(
+         "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
+     )
+     parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
+     parser.add_argument(
+         "--save_or_load_index", action="store_true", help="If enabled, save the index and load it if it exists"
+     )
+     parser.add_argument(
+         "--model_name_or_path", type=str, help="Path to directory containing model weights and config file"
+     )
+     parser.add_argument("--no_fp16", action="store_true", help="Run inference in fp32")
+     parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
+     parser.add_argument(
+         "--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
+     )
+     parser.add_argument("--projection_size", type=int, default=768)
+     parser.add_argument(
+         "--n_subquantizers",
+         type=int,
+         default=0,
+         help="Number of subquantizers used for vector quantization; if 0, a flat index is used",
+     )
+     parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
+     parser.add_argument("--lang", nargs="+")
+     parser.add_argument("--dataset", type=str, default="none")
+     parser.add_argument("--lowercase", action="store_true", help="Lowercase text before encoding")
+     parser.add_argument("--normalize_text", action="store_true", help="Normalize text")
+
+     args = parser.parse_args()
+     src.slurm.init_distributed_mode(args)
+     main(args)
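
The script above embeds questions in batches, streams pickled passage embeddings into the index in fixed-size chunks (index_encoded_data / add_embeddings), then searches the top-k passages per question. Below is a minimal sketch of that chunked add-and-search pattern against a plain FAISS inner-product index, in place of the repo's src.index.Indexer wrapper; the dimension, chunk size, and random arrays are illustrative assumptions, not values taken from this codebase.

import numpy as np
import faiss  # assumed installed; src.index wraps a FAISS index along these lines

dim = 768  # matches the --projection_size default above
index = faiss.IndexFlatIP(dim)  # flat inner-product index, i.e. --n_subquantizers 0

def add_in_chunks(index, embeddings, chunk_size=1000):
    # Same idea as add_embeddings(): feed the index fixed-size slices.
    for start in range(0, embeddings.shape[0], chunk_size):
        index.add(embeddings[start:start + chunk_size])

passage_emb = np.random.rand(10000, dim).astype("float32")  # stand-in for pickled shards
query_emb = np.random.rand(4, dim).astype("float32")        # stand-in for embed_queries() output
add_in_chunks(index, passage_emb)
scores, ids = index.search(query_emb, 100)  # top-100 passages per query, like search_knn
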
sentence-transformers/preprocess.py ADDED
@@ -0,0 +1,68 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+ import os
+ import argparse
+ import torch
+
+ import transformers
+ from src.normalize_text import normalize
+
+
+ def save(tensor, split_path):
+     os.makedirs(os.path.dirname(split_path), exist_ok=True)
+     with open(split_path, 'wb') as fout:
+         torch.save(tensor, fout)
+
+
+ def apply_tokenizer(path, tokenizer, normalize_text=False):
+     alltokens = []
+     lines = []
+     with open(path, "r", encoding="utf-8") as fin:
+         for k, line in enumerate(fin):
+             if normalize_text:
+                 line = normalize(line)
+
+             lines.append(line)
+             # Tokenize in chunks of ~1M lines to bound memory usage.
+             if len(lines) > 1000000:
+                 tokens = tokenizer.batch_encode_plus(lines, add_special_tokens=False)['input_ids']
+                 tokens = [torch.tensor(x, dtype=torch.int) for x in tokens]
+                 alltokens.extend(tokens)
+                 lines = []
+
+     # Flush the remaining lines.
+     tokens = tokenizer.batch_encode_plus(lines, add_special_tokens=False)['input_ids']
+     tokens = [torch.tensor(x, dtype=torch.int) for x in tokens]
+     alltokens.extend(tokens)
+
+     alltokens = torch.cat(alltokens)
+     return alltokens
+
+
+ def tokenize_file(args):
+     filename = os.path.basename(args.datapath)
+     savepath = os.path.join(args.outdir, f"{filename}.pkl")
+     if os.path.exists(savepath):
+         if args.overwrite:
+             print(f"File {savepath} already exists, overwriting")
+         else:
+             print(f"File {savepath} already exists, exiting")
+             return
+     try:
+         tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer, local_files_only=True)
+     except Exception:
+         tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer, local_files_only=False)
+     print(f"Encoding {args.datapath}...")
+     tokens = apply_tokenizer(args.datapath, tokenizer, normalize_text=args.normalize_text)
+
+     print(f"Saving at {savepath}...")
+     save(tokens, savepath)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+     parser.add_argument("--datapath", type=str)
+     parser.add_argument("--outdir", type=str)
+     parser.add_argument("--tokenizer", type=str)
+     parser.add_argument("--overwrite", action="store_true")
+     parser.add_argument("--normalize_text", action="store_true")
+
+     args, _ = parser.parse_known_args()
+     tokenize_file(args)
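
preprocess.py streams a corpus through the tokenizer in roughly one-million-line chunks and stores the concatenated ids as a single flat torch.int tensor. A small sketch of reading an output shard back, with a placeholder path:

import torch

# Despite the .pkl suffix, the file is written with torch.save, so
# torch.load returns the flat 1-D tensor of token ids directly.
tokens = torch.load("encoded/corpus.txt.pkl")  # placeholder path
print(tokens.shape, tokens.dtype)  # e.g. torch.Size([N]) torch.int32
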
sentence-transformers/requirements.txt ADDED
@@ -0,0 +1,11 @@
+ transformers>=4.6.0,<5.0.0
+ tokenizers>=0.10.3
+ tqdm
+ torch>=1.6.0
+ torchvision
+ numpy
+ scikit-learn
+ scipy
+ nltk
+ sentencepiece
+ huggingface-hub
sentence-transformers/setup.cfg ADDED
@@ -0,0 +1,2 @@
+ [metadata]
+ description-file = README.md
sentence-transformers/setup.py ADDED
@@ -0,0 +1,41 @@
+ from setuptools import setup, find_packages
+
+ with open("README.md", mode="r", encoding="utf-8") as readme_file:
+     readme = readme_file.read()
+
+
+ setup(
+     name="sentence-transformers",
+     version="2.2.2",
+     author="Nils Reimers",
+     author_email="info@nils-reimers.de",
+     description="Multilingual text embeddings",
+     long_description=readme,
+     long_description_content_type="text/markdown",
+     license="Apache License 2.0",
+     url="https://www.SBERT.net",
+     download_url="https://github.com/UKPLab/sentence-transformers/",
+     packages=find_packages(),
+     python_requires=">=3.6.0",
+     install_requires=[
+         'transformers>=4.6.0,<5.0.0',
+         'tqdm',
+         'torch>=1.6.0',
+         'torchvision',
+         'numpy',
+         'scikit-learn',
+         'scipy',
+         'nltk',
+         'sentencepiece',
+         'huggingface-hub>=0.4.0'
+     ],
+     classifiers=[
+         "Development Status :: 5 - Production/Stable",
+         "Intended Audience :: Science/Research",
+         "License :: OSI Approved :: Apache Software License",
+         "Programming Language :: Python :: 3.6",
+         "Topic :: Scientific/Engineering :: Artificial Intelligence"
+     ],
+     keywords="Transformer Networks BERT XLNet sentence embedding PyTorch NLP deep learning"
+ )
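
With this setup.py, a local editable install is the usual pip install -e . from the repo root. A quick post-install sanity check, assuming the install succeeded (the version string comes from the setup() call above):

import sentence_transformers

# Should report the version declared in setup.py after an editable install.
print(sentence_transformers.__version__)  # expected: 2.2.2
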
sentence-transformers/train.py ADDED
@@ -0,0 +1,195 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+ import os
+ import time
+ import sys
+ import torch
+ import logging
+ import json
+ import numpy as np
+ import random
+ import pickle
+
+ import torch.distributed as dist
+ from torch.utils.data import DataLoader, RandomSampler
+
+ from src.options import Options
+ from src import data, beir_utils, slurm, dist_utils, utils
+ from src import moco, inbatch
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def train(opt, model, optimizer, scheduler, step):
+
+     run_stats = utils.WeightedAvgStats()
+
+     tb_logger = utils.init_tb_logger(opt.output_dir)
+
+     logger.info("Data loading")
+     if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+         tokenizer = model.module.tokenizer
+     else:
+         tokenizer = model.tokenizer
+     collator = data.Collator(opt=opt)
+     train_dataset = data.load_data(opt, tokenizer)
+     logger.warning(f"Data loading finished for rank {dist_utils.get_rank()}")
+
+     train_sampler = RandomSampler(train_dataset)
+     train_dataloader = DataLoader(
+         train_dataset,
+         sampler=train_sampler,
+         batch_size=opt.per_gpu_batch_size,
+         drop_last=True,
+         num_workers=opt.num_workers,
+         collate_fn=collator,
+     )
+
+     epoch = 1
+
+     model.train()
+     while step < opt.total_steps:
+         train_dataset.generate_offset()
+
+         logger.info(f"Start epoch {epoch}")
+         for i, batch in enumerate(train_dataloader):
+             step += 1
+
+             batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
+             train_loss, iter_stats = model(**batch, stats_prefix="train")
+
+             train_loss.backward()
+             optimizer.step()
+
+             scheduler.step()
+             model.zero_grad()
+
+             run_stats.update(iter_stats)
+
+             if step % opt.log_freq == 0:
+                 log = f"{step} / {opt.total_steps}"
+                 for k, v in sorted(run_stats.average_stats.items()):
+                     log += f" | {k}: {v:.3f}"
+                     if tb_logger:
+                         tb_logger.add_scalar(k, v, step)
+                 log += f" | lr: {scheduler.get_last_lr()[0]:0.3g}"
+                 log += f" | Memory: {torch.cuda.max_memory_allocated() // 2**30} GiB"
+
+                 logger.info(log)
+                 run_stats.reset()
+
+             if step % opt.eval_freq == 0:
+                 if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+                     encoder = model.module.get_encoder()
+                 else:
+                     encoder = model.get_encoder()
+                 eval_model(
+                     opt, query_encoder=encoder, doc_encoder=encoder, tokenizer=tokenizer, tb_logger=tb_logger, step=step
+                 )
+
+                 if dist_utils.is_main():
+                     utils.save(model, optimizer, scheduler, step, opt, opt.output_dir, "lastlog")
+
+                 model.train()
+
+             if dist_utils.is_main() and step % opt.save_freq == 0:
+                 utils.save(model, optimizer, scheduler, step, opt, opt.output_dir, f"step-{step}")
+
+             if step > opt.total_steps:
+                 break
+         epoch += 1
+
+
+ def eval_model(opt, query_encoder, doc_encoder, tokenizer, tb_logger, step):
+     for datasetname in opt.eval_datasets:
+         metrics = beir_utils.evaluate_model(
+             query_encoder,
+             doc_encoder,
+             tokenizer,
+             dataset=datasetname,
+             batch_size=opt.per_gpu_eval_batch_size,
+             norm_doc=opt.norm_doc,
+             norm_query=opt.norm_query,
+             beir_dir=opt.eval_datasets_dir,
+             score_function=opt.score_function,
+             lower_case=opt.lower_case,
+             normalize_text=opt.eval_normalize_text,
+         )
+
+         message = []
+         if dist_utils.is_main():
+             for metric in ["NDCG@10", "Recall@10", "Recall@100"]:
+                 message.append(f"{datasetname}/{metric}: {metrics[metric]:.2f}")
+                 if tb_logger is not None:
+                     tb_logger.add_scalar(f"{datasetname}/{metric}", metrics[metric], step)
+             logger.info(" | ".join(message))
+
+
+ if __name__ == "__main__":
+     logger.info("Start")
+
+     options = Options()
+     opt = options.parse()
+
+     torch.manual_seed(opt.seed)
+     slurm.init_distributed_mode(opt)
+     slurm.init_signal_handler()
+
+     directory_exists = os.path.isdir(opt.output_dir)
+     if dist.is_initialized():
+         dist.barrier()
+     os.makedirs(opt.output_dir, exist_ok=True)
+     if not directory_exists and dist_utils.is_main():
+         options.print_options(opt)
+     if dist.is_initialized():
+         dist.barrier()
+     utils.init_logger(opt)
+
+     os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+     if opt.contrastive_mode == "moco":
+         model_class = moco.MoCo
+     elif opt.contrastive_mode == "inbatch":
+         model_class = inbatch.InBatch
+     else:
+         raise ValueError(f"contrastive mode: {opt.contrastive_mode} not recognised")
+
+     if not directory_exists and opt.model_path == "none":
+         # Fresh run: initialize a new model.
+         model = model_class(opt)
+         model = model.cuda()
+         optimizer, scheduler = utils.set_optim(opt, model)
+         step = 0
+     elif directory_exists:
+         # Resume from the latest checkpoint in the output directory.
+         model_path = os.path.join(opt.output_dir, "checkpoint", "latest")
+         model, optimizer, scheduler, opt_checkpoint, step = utils.load(
+             model_class,
+             model_path,
+             opt,
+             reset_params=False,
+         )
+         logger.info(f"Model loaded from {opt.output_dir}")
+     else:
+         # Warm-start from an explicit checkpoint path.
+         model, optimizer, scheduler, opt_checkpoint, step = utils.load(
+             model_class,
+             opt.model_path,
+             opt,
+             reset_params=not opt.continue_training,
+         )
+         if not opt.continue_training:
+             step = 0
+         logger.info(f"Model loaded from {opt.model_path}")
+
+     logger.info(utils.get_parameters(model))
+
+     if dist.is_initialized():
+         model = torch.nn.parallel.DistributedDataParallel(
+             model,
+             device_ids=[opt.local_rank],
+             output_device=opt.local_rank,
+             find_unused_parameters=False,
+         )
+         dist.barrier()
+
+     logger.info("Start training")
+     train(opt, model, optimizer, scheduler, step)
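
train.py optimizes either a MoCo or an in-batch contrastive objective (moco.MoCo / inbatch.InBatch) over the collated batches, with periodic BEIR evaluation and checkpointing. As a rough sketch of what an in-batch objective computes, here is a standard InfoNCE loss over paired query/passage embeddings; the temperature, normalization, and shapes are illustrative assumptions, not the repo's exact implementation:

import torch
import torch.nn.functional as F

def inbatch_contrastive_loss(q, k, temperature=0.05):
    # q, k: (batch, dim) query and positive-passage embeddings; every other
    # passage in the batch serves as a negative for a given query.
    q = F.normalize(q, dim=-1)
    k = F.normalize(k, dim=-1)
    logits = q @ k.t() / temperature           # (batch, batch) similarity matrix
    labels = torch.arange(q.size(0), device=q.device)
    return F.cross_entropy(logits, labels)     # diagonal entries are the positives

loss = inbatch_contrastive_loss(torch.randn(8, 768), torch.randn(8, 768))
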