Upload 17 files
- sentence-transformers/.DS_Store +0 -0
- sentence-transformers/CODE_OF_CONDUCT.md +5 -0
- sentence-transformers/CONTRIBUTING.md +16 -0
- sentence-transformers/LICENSE +201 -0
- sentence-transformers/NOTICE.txt +5 -0
- sentence-transformers/README.md +182 -0
- sentence-transformers/eval_beir.py +89 -0
- sentence-transformers/evaluate_retrieved_passages.py +66 -0
- sentence-transformers/finetuning.py +249 -0
- sentence-transformers/generate_passage_embeddings.py +124 -0
- sentence-transformers/index.rst +189 -0
- sentence-transformers/passage_retrieval.py +249 -0
- sentence-transformers/preprocess.py +68 -0
- sentence-transformers/requirements.txt +11 -0
- sentence-transformers/setup.cfg +2 -0
- sentence-transformers/setup.py +41 -0
- sentence-transformers/train.py +195 -0
sentence-transformers/.DS_Store
ADDED
Binary file (8.2 kB).
sentence-transformers/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,5 @@
# Code of Conduct

Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.fb.com/codeofconduct/)
so that you can understand what actions will and will not be tolerated.
sentence-transformers/CONTRIBUTING.md
ADDED
@@ -0,0 +1,16 @@
# Contributing to this repo

## Pull Requests

In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues

We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

## License

By contributing to this repo, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
sentence-transformers/LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright 2019 Nils Reimers

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
sentence-transformers/NOTICE.txt
ADDED
@@ -0,0 +1,5 @@
-------------------------------------------------------------------------------
Copyright 2019
Ubiquitous Knowledge Processing (UKP) Lab
Technische Universität Darmstadt
-------------------------------------------------------------------------------
sentence-transformers/README.md
ADDED
@@ -0,0 +1,182 @@
<!--- BADGES: START --->
[![GitHub - License](https://img.shields.io/github/license/UKPLab/sentence-transformers?logo=github&style=flat&color=green)][#github-license]
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/sentence-transformers?logo=pypi&style=flat&color=blue)][#pypi-package]
[![PyPI - Package Version](https://img.shields.io/pypi/v/sentence-transformers?logo=pypi&style=flat&color=orange)][#pypi-package]
[![Conda - Platform](https://img.shields.io/conda/pn/conda-forge/sentence-transformers?logo=anaconda&style=flat)][#conda-forge-package]
[![Conda (channel only)](https://img.shields.io/conda/vn/conda-forge/sentence-transformers?logo=anaconda&style=flat&color=orange)][#conda-forge-package]
[![Docs - GitHub.io](https://img.shields.io/static/v1?logo=github&style=flat&color=pink&label=docs&message=sentence-transformers)][#docs-package]
<!---
[![PyPI - Downloads](https://img.shields.io/pypi/dm/sentence-transformers?logo=pypi&style=flat&color=green)][#pypi-package]
[![Conda](https://img.shields.io/conda/dn/conda-forge/sentence-transformers?logo=anaconda)][#conda-forge-package]
--->

[#github-license]: https://github.com/UKPLab/sentence-transformers/blob/master/LICENSE
[#pypi-package]: https://pypi.org/project/sentence-transformers/
[#conda-forge-package]: https://anaconda.org/conda-forge/sentence-transformers
[#docs-package]: https://www.sbert.net/
<!--- BADGES: END --->

# Sentence Transformers: Multilingual Sentence, Paragraph, and Image Embeddings using BERT & Co.

This framework provides an easy method to compute dense vector representations for **sentences**, **paragraphs**, and **images**. The models are based on transformer networks like BERT / RoBERTa / XLM-RoBERTa etc. and achieve state-of-the-art performance in various tasks. Text is embedded in a vector space such that similar texts are close and can be found efficiently using cosine similarity.

We provide an increasing number of **[state-of-the-art pretrained models](https://www.sbert.net/docs/pretrained_models.html)** for more than 100 languages, fine-tuned for various use cases.

Further, this framework allows easy **[fine-tuning of custom embedding models](https://www.sbert.net/docs/training/overview.html)** to achieve maximal performance on your specific task.

For the **full documentation**, see **[www.SBERT.net](https://www.sbert.net)**.

The following publications are integrated in this framework:

- [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084) (EMNLP 2019)
- [Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation](https://arxiv.org/abs/2004.09813) (EMNLP 2020)
- [Augmented SBERT: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks](https://arxiv.org/abs/2010.08240) (NAACL 2021)
- [The Curse of Dense Low-Dimensional Information Retrieval for Large Index Sizes](https://arxiv.org/abs/2012.14210) (arXiv 2020)
- [TSDAE: Using Transformer-based Sequential Denoising Auto-Encoder for Unsupervised Sentence Embedding Learning](https://arxiv.org/abs/2104.06979) (arXiv 2021)
- [BEIR: A Heterogenous Benchmark for Zero-shot Evaluation of Information Retrieval Models](https://arxiv.org/abs/2104.08663) (arXiv 2021)

## Installation

We recommend **Python 3.6** or higher, **[PyTorch 1.6.0](https://pytorch.org/get-started/locally/)** or higher, and **[transformers v4.6.0](https://github.com/huggingface/transformers)** or higher. The code does **not** work with Python 2.7.

**Install with pip**

Install *sentence-transformers* with `pip`:

```
pip install -U sentence-transformers
```

**Install with conda**

You can install *sentence-transformers* with `conda`:

```
conda install -c conda-forge sentence-transformers
```

**Install from sources**

Alternatively, you can clone the latest version from the [repository](https://github.com/UKPLab/sentence-transformers) and install it directly from the source code:

```
pip install -e .
```

**PyTorch with CUDA**

If you want to use a GPU / CUDA, you must install PyTorch with a matching CUDA version. Follow
[PyTorch - Get Started](https://pytorch.org/get-started/locally/) for further details on how to install PyTorch.

## Getting Started

See [Quickstart](https://www.sbert.net/docs/quickstart.html) in our documentation.

[This example](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications/computing-embeddings/computing_embeddings.py) shows you how to use an already trained Sentence Transformer model to embed sentences for another task.

First download a pretrained model.

```python
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2')
```

Then provide some sentences to the model.

```python
sentences = ['This framework generates embeddings for each input sentence',
    'Sentences are passed as a list of strings.',
    'The quick brown fox jumps over the lazy dog.']
sentence_embeddings = model.encode(sentences)
```

And that's it already. We now have a list of numpy arrays with the embeddings.

```python
for sentence, embedding in zip(sentences, sentence_embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding)
    print("")
```
102 |
+
## Pre-Trained Models
|
103 |
+
|
104 |
+
We provide a large list of [Pretrained Models](https://www.sbert.net/docs/pretrained_models.html) for more than 100 languages. Some models are general purpose models, while others produce embeddings for specific use cases. Pre-trained models can be loaded by just passing the model name: `SentenceTransformer('model_name')`.
|
105 |
+
|
106 |
+
[» Full list of pretrained models](https://www.sbert.net/docs/pretrained_models.html)
|
107 |
+
|
108 |
+
## Training
|
109 |
+
|
110 |
+
This framework allows you to fine-tune your own sentence embedding methods, so that you get task-specific sentence embeddings. You have various options to choose from in order to get perfect sentence embeddings for your specific task.
|
111 |
+
|
112 |
+
See [Training Overview](https://www.sbert.net/docs/training/overview.html) for an introduction how to train your own embedding models. We provide [various examples](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training) how to train models on various datasets.
|
113 |
+
|
114 |
+
Some highlights are:
|
115 |
+
- Support of various transformer networks including BERT, RoBERTa, XLM-R, DistilBERT, Electra, BART, ...
|
116 |
+
- Multi-Lingual and multi-task learning
|
117 |
+
- Evaluation during training to find optimal model
|
118 |
+
- [10+ loss-functions](https://www.sbert.net/docs/package_reference/losses.html) allowing to tune models specifically for semantic search, paraphrase mining, semantic similarity comparison, clustering, triplet loss, contrastive loss.
|
119 |
+
|
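As a concrete illustration of the highlights above, here is a minimal fine-tuning sketch using the standard `model.fit` training API; the sentence pairs and similarity labels are made up for the example:

```python
from torch.utils.data import DataLoader

from sentence_transformers import InputExample, SentenceTransformer, losses

model = SentenceTransformer('all-MiniLM-L6-v2')

# Toy training data: sentence pairs with a gold similarity score in [0, 1].
train_examples = [
    InputExample(texts=['My first sentence', 'My second sentence'], label=0.8),
    InputExample(texts=['Another pair', 'Something unrelated'], label=0.3),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.CosineSimilarityLoss(model)

# One pass over the data; warmup and evaluation settings kept minimal here.
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)
```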

## Performance

Our models are evaluated extensively on 15+ datasets, including challenging domains like tweets, Reddit, and emails. They achieve by far the **best performance** of all available sentence embedding methods. Further, we provide several **smaller models** that are **optimized for speed**.

[» Full list of pretrained models](https://www.sbert.net/docs/pretrained_models.html)

## Application Examples

You can use this framework for:

- [Computing Sentence Embeddings](https://www.sbert.net/examples/applications/computing-embeddings/README.html)
- [Semantic Textual Similarity](https://www.sbert.net/docs/usage/semantic_textual_similarity.html)
- [Clustering](https://www.sbert.net/examples/applications/clustering/README.html)
- [Paraphrase Mining](https://www.sbert.net/examples/applications/paraphrase-mining/README.html)
- [Translated Sentence Mining](https://www.sbert.net/examples/applications/parallel-sentence-mining/README.html)
- [Semantic Search](https://www.sbert.net/examples/applications/semantic-search/README.html)
- [Retrieve & Re-Rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html)
- [Text Summarization](https://www.sbert.net/examples/applications/text-summarization/README.html)
- [Multilingual Image Search, Clustering & Duplicate Detection](https://www.sbert.net/examples/applications/image-search/README.html)

and many more use cases.

For all examples, see [examples/applications](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications).

## Citing & Authors

If you find this repository helpful, feel free to cite our publication [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084):

```bibtex
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}
```

If you use one of the multilingual models, feel free to cite our publication [Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation](https://arxiv.org/abs/2004.09813):

```bibtex
@inproceedings{reimers-2020-multilingual-sentence-bert,
    title = "Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2020",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/2004.09813",
}
```

Please have a look at [Publications](https://www.sbert.net/docs/publications.html) for our different publications that are integrated into SentenceTransformers.

Contact person: [Nils Reimers](https://www.nils-reimers.de), [info@nils-reimers.de](mailto:info@nils-reimers.de)

https://www.ukp.tu-darmstadt.de/

Don't hesitate to send us an e-mail or report an issue if something is broken (and it shouldn't be) or if you have further questions.

> This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.
sentence-transformers/eval_beir.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Example invocation (hypothetical paths/values):
#   python eval_beir.py --model_name_or_path ./checkpoint --dataset scifact --beir_dir ./beir_data

import argparse
import logging
import os

import src.slurm
import src.contriever
import src.beir_utils
import src.utils
import src.dist_utils

logger = logging.getLogger(__name__)


def main(args):
    src.slurm.init_distributed_mode(args)
    src.slurm.init_signal_handler()

    os.makedirs(args.output_dir, exist_ok=True)

    logger = src.utils.init_logger(args)

    # The same encoder is used for both queries and documents.
    model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path)
    model = model.cuda()
    model.eval()
    query_encoder = model
    doc_encoder = model

    logger.info("Start indexing")

    metrics = src.beir_utils.evaluate_model(
        query_encoder=query_encoder,
        doc_encoder=doc_encoder,
        tokenizer=tokenizer,
        dataset=args.dataset,
        batch_size=args.per_gpu_batch_size,
        norm_query=args.norm_query,
        norm_doc=args.norm_doc,
        is_main=src.dist_utils.is_main(),
        split="dev" if args.dataset == "msmarco" else "test",
        score_function=args.score_function,
        beir_dir=args.beir_dir,
        save_results_path=args.save_results_path,
        lower_case=args.lower_case,
        normalize_text=args.normalize_text,
    )

    if src.dist_utils.is_main():
        for key, value in metrics.items():
            logger.info(f"{args.dataset} : {key}: {value:.1f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--dataset", type=str, help="Evaluation dataset from the BEIR benchmark")
    parser.add_argument("--beir_dir", type=str, default="./", help="Directory to save and load BEIR datasets")
    parser.add_argument("--text_maxlength", type=int, default=512, help="Maximum text length")
    parser.add_argument("--per_gpu_batch_size", default=128, type=int, help="Batch size per GPU/CPU for indexing.")
    parser.add_argument("--output_dir", type=str, default="./my_experiment", help="Output directory")
    parser.add_argument("--model_name_or_path", type=str, help="Model name or path")
    parser.add_argument(
        "--score_function", type=str, default="dot", help="Metric used to compute similarity between two embeddings"
    )
    parser.add_argument("--norm_query", action="store_true", help="Normalize query representation")
    parser.add_argument("--norm_doc", action="store_true", help="Normalize document representation")
    parser.add_argument("--lower_case", action="store_true", help="Lowercase query and document text")
    parser.add_argument(
        "--normalize_text", action="store_true", help="Apply function to normalize some common characters"
    )
    parser.add_argument("--save_results_path", type=str, default=None, help="Path to save result object")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--main_port", type=int, default=-1, help="Main port (for multi-node SLURM jobs)")

    args, _ = parser.parse_known_args()
    main(args)
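One note on the `--score_function`, `--norm_query`, and `--norm_doc` flags above: L2-normalizing both sides makes the default dot-product score coincide with cosine similarity. A toy, self-contained illustration with random tensors (no project code involved):

```python
import torch
import torch.nn.functional as F

q = torch.randn(4, 8)  # toy query embeddings
d = torch.randn(6, 8)  # toy document embeddings

dot_scores = q @ d.T  # what --score_function dot computes

# With --norm_query/--norm_doc the embeddings are unit length, so the
# dot product of the normalized vectors equals cosine similarity.
cos_scores = F.normalize(q, dim=1) @ F.normalize(d, dim=1).T

print(dot_scores.shape, cos_scores.shape)  # torch.Size([4, 6]) twice
```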
sentence-transformers/evaluate_retrieved_passages.py
ADDED
@@ -0,0 +1,66 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Example invocation (hypothetical path): python evaluate_retrieved_passages.py --data "retrieved/*.jsonl"

import argparse
import glob
import json
import logging

import src.utils

from src.evaluation import calculate_matches

logger = logging.getLogger(__name__)


def validate(data, workers_num):
    match_stats = calculate_matches(data, workers_num)
    top_k_hits = match_stats.top_k_hits

    # Convert absolute hit counts into top-k accuracies.
    top_k_hits = [v / len(data) for v in top_k_hits]
    return top_k_hits


def main(args):
    logger = src.utils.init_logger(args, stdout_only=True)
    datapaths = glob.glob(args.data)
    r20, r100 = [], []
    for path in datapaths:
        data = []
        with open(path, "r") as fin:
            for line in fin:
                data.append(json.loads(line))
        top_k_hits = validate(data, args.validation_workers)
        message = f"Evaluate results from {path}:"
        for k in [5, 10, 20, 100]:
            if k <= len(top_k_hits):
                recall = 100 * top_k_hits[k - 1]
                if k == 20:
                    r20.append(f"{recall:.1f}")
                if k == 100:
                    r100.append(f"{recall:.1f}")
                message += f" R@{k}: {recall:.1f}"
        logger.info(message)
    print(datapaths)
    print("\t".join(r20))
    print("\t".join(r100))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--data", required=True, type=str, default=None)
    parser.add_argument("--validation_workers", type=int, default=16,
                        help="Number of parallel processes to validate results")

    args = parser.parse_args()
    main(args)
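To make the hits-to-recall conversion above concrete, here is a small worked sketch with made-up numbers; `top_k_hits[i]` is the fraction of questions whose gold answer appears among the top `i + 1` retrieved passages:

```python
# Hypothetical accuracy curve over the first 100 retrieved passages.
top_k_hits = [0.0] * 100
top_k_hits[4], top_k_hits[9], top_k_hits[19], top_k_hits[99] = 0.55, 0.63, 0.72, 0.86

# Mirrors the reporting loop in main(): R@k = 100 * top_k_hits[k - 1].
for k in [5, 10, 20, 100]:
    print(f"R@{k}: {100 * top_k_hits[k - 1]:.1f}")
```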
sentence-transformers/finetuning.py
ADDED
@@ -0,0 +1,249 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import os
import logging

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from src.options import Options
from src import data, beir_utils, slurm, dist_utils, utils, contriever, finetuning_data, inbatch

import train

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logger = logging.getLogger(__name__)


def finetuning(opt, model, optimizer, scheduler, tokenizer, step):
    run_stats = utils.WeightedAvgStats()

    tb_logger = utils.init_tb_logger(opt.output_dir)

    # Unwrap DistributedDataParallel and evaluate only the encoder.
    if hasattr(model, "module"):
        eval_model = model.module
    else:
        eval_model = model
    eval_model = eval_model.get_encoder()

    train_dataset = finetuning_data.Dataset(
        datapaths=opt.train_data,
        negative_ctxs=opt.negative_ctxs,
        negative_hard_ratio=opt.negative_hard_ratio,
        negative_hard_min_idx=opt.negative_hard_min_idx,
        normalize=opt.eval_normalize_text,
        global_rank=dist_utils.get_rank(),
        world_size=dist_utils.get_world_size(),
        maxload=opt.maxload,
        training=True,
    )
    collator = finetuning_data.Collator(tokenizer, passage_maxlength=opt.chunk_length)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=opt.per_gpu_batch_size,
        drop_last=True,
        num_workers=opt.num_workers,
        collate_fn=collator,
    )

    # Evaluate once before any fine-tuning step.
    train.eval_model(opt, eval_model, None, tokenizer, tb_logger, step)
    evaluate(opt, eval_model, tokenizer, tb_logger, step)

    epoch = 1

    model.train()
    while step < opt.total_steps:
        logger.info(f"Start epoch {epoch}, number of batches: {len(train_dataloader)}")
        for i, batch in enumerate(train_dataloader):
            batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
            step += 1

            train_loss, iter_stats = model(**batch, stats_prefix="train")
            train_loss.backward()

            if opt.optim == "sam" or opt.optim == "asam":
                # SAM/ASAM needs a second forward/backward pass at the perturbed weights.
                optimizer.first_step(zero_grad=True)

                sam_loss, _ = model(**batch, stats_prefix="train/sam_opt")
                sam_loss.backward()
                optimizer.second_step(zero_grad=True)
            else:
                optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            run_stats.update(iter_stats)

            if step % opt.log_freq == 0:
                log = f"{step} / {opt.total_steps}"
                for k, v in sorted(run_stats.average_stats.items()):
                    log += f" | {k}: {v:.3f}"
                    if tb_logger:
                        tb_logger.add_scalar(k, v, step)
                log += f" | lr: {scheduler.get_last_lr()[0]:0.3g}"
                log += f" | Memory: {torch.cuda.max_memory_allocated()//1e9} GiB"

                logger.info(log)
                run_stats.reset()

            if step % opt.eval_freq == 0:
                train.eval_model(opt, eval_model, None, tokenizer, tb_logger, step)
                evaluate(opt, eval_model, tokenizer, tb_logger, step)

                if step % opt.save_freq == 0 and dist_utils.get_rank() == 0:
                    utils.save(
                        eval_model,
                        optimizer,
                        scheduler,
                        step,
                        opt,
                        opt.output_dir,
                        f"step-{step}",
                    )
                model.train()

            if step >= opt.total_steps:
                break

        epoch += 1


def evaluate(opt, model, tokenizer, tb_logger, step):
    dataset = finetuning_data.Dataset(
        datapaths=opt.eval_data,
        normalize=opt.eval_normalize_text,
        global_rank=dist_utils.get_rank(),
        world_size=dist_utils.get_world_size(),
        maxload=opt.maxload,
        training=False,
    )
    collator = finetuning_data.Collator(tokenizer, passage_maxlength=opt.chunk_length)
    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=opt.per_gpu_batch_size,
        drop_last=False,
        num_workers=opt.num_workers,
        collate_fn=collator,
    )

    model.eval()
    if hasattr(model, "module"):
        model = model.module
    all_q, all_g, all_n = [], [], []
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}

            # Encode gold and negative passages in a single forward pass.
            all_tokens = torch.cat([batch["g_tokens"], batch["n_tokens"]], dim=0)
            all_mask = torch.cat([batch["g_mask"], batch["n_mask"]], dim=0)

            q_emb = model(input_ids=batch["q_tokens"], attention_mask=batch["q_mask"], normalize=opt.norm_query)
            all_emb = model(input_ids=all_tokens, attention_mask=all_mask, normalize=opt.norm_doc)

            g_emb, n_emb = torch.split(all_emb, [len(batch["g_tokens"]), len(batch["n_tokens"])])

            all_q.append(q_emb)
            all_g.append(g_emb)
            all_n.append(n_emb)

        all_q = torch.cat(all_q, dim=0)
        all_g = torch.cat(all_g, dim=0)
        all_n = torch.cat(all_n, dim=0)

        labels = torch.arange(0, len(all_q), device=all_q.device, dtype=torch.long)

        # Gather gold/negative passages from all ranks and shift labels accordingly.
        all_sizes = dist_utils.get_varsize(all_g)
        all_g = dist_utils.varsize_gather_nograd(all_g)
        all_n = dist_utils.varsize_gather_nograd(all_n)
        labels = labels + sum(all_sizes[: dist_utils.get_rank()])

        scores_pos = torch.einsum("id, jd->ij", all_q, all_g)
        scores_neg = torch.einsum("id, jd->ij", all_q, all_n)
        scores = torch.cat([scores_pos, scores_neg], dim=-1)

        # Top-1 accuracy and MRR of the gold passage among all candidates.
        argmax_idx = torch.argmax(scores, dim=1)
        sorted_scores, indices = torch.sort(scores, descending=True)
        isrelevant = indices == labels[:, None]
        rs = [r.cpu().numpy().nonzero()[0] for r in isrelevant]
        mrr = np.mean([1.0 / (r[0] + 1) if r.size else 0.0 for r in rs])

        acc = (argmax_idx == labels).sum() / all_q.size(0)
        acc, total = dist_utils.weighted_average(acc, all_q.size(0))
        mrr, _ = dist_utils.weighted_average(mrr, all_q.size(0))
        acc = 100 * acc

        if dist_utils.is_main():
            message = [f"eval acc: {acc:.2f}%", f"eval mrr: {mrr:.3f}"]
            logger.info(" | ".join(message))
            if tb_logger is not None:
                tb_logger.add_scalar("eval_acc", acc, step)
                tb_logger.add_scalar("mrr", mrr, step)


def main():
    logger.info("Start")

    options = Options()
    opt = options.parse()

    torch.manual_seed(opt.seed)
    slurm.init_distributed_mode(opt)
    slurm.init_signal_handler()

    directory_exists = os.path.isdir(opt.output_dir)
    if dist.is_initialized():
        dist.barrier()
    os.makedirs(opt.output_dir, exist_ok=True)
    if not directory_exists and dist_utils.is_main():
        options.print_options(opt)
    if dist.is_initialized():
        dist.barrier()
    utils.init_logger(opt)

    step = 0

    retriever, tokenizer, retriever_model_id = contriever.load_retriever(opt.model_path, opt.pooling, opt.random_init)
    opt.retriever_model_id = retriever_model_id
    model = inbatch.InBatch(opt, retriever, tokenizer)

    model = model.cuda()

    optimizer, scheduler = utils.set_optim(opt, model)
    logger.info(utils.get_parameters(model))

    # Override the dropout probability throughout the encoder.
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = opt.dropout

    if torch.distributed.is_initialized():
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            find_unused_parameters=False,
        )

    logger.info("Start training")
    finetuning(opt, model, optimizer, scheduler, tokenizer, step)


if __name__ == "__main__":
    main()
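The metric block inside `evaluate` above is dense, so here is a standalone sketch of the same top-1 accuracy and MRR computation on a made-up score matrix:

```python
import numpy as np
import torch

# Two queries, three candidate passages each; labels point at the gold passage.
scores = torch.tensor([[0.9, 0.2, 0.1],
                       [0.3, 0.8, 0.5]])
labels = torch.tensor([0, 2])

# Top-1 accuracy: does the highest-scoring candidate match the gold index?
acc = (torch.argmax(scores, dim=1) == labels).float().mean()

# MRR: 1 / (1 + rank of the gold candidate in the sorted score list).
_, indices = torch.sort(scores, descending=True)
isrelevant = indices == labels[:, None]
ranks = [r.nonzero()[0].item() for r in isrelevant]
mrr = np.mean([1.0 / (rank + 1) for rank in ranks])

print(f"acc={acc:.2f} mrr={mrr:.3f}")  # acc=0.50 mrr=0.750
```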
sentence-transformers/generate_passage_embeddings.py
ADDED
@@ -0,0 +1,124 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Example invocation (hypothetical paths):
#   python generate_passage_embeddings.py --model_name_or_path ./checkpoint \
#       --passages passages.tsv --shard_id 0 --num_shards 4

import argparse
import os
import pickle

import torch

import src.slurm
import src.contriever
import src.data
import src.normalize_text


def embed_passages(args, passages, model, tokenizer):
    total = 0
    allids, allembeddings = [], []
    batch_ids, batch_text = [], []
    with torch.no_grad():
        for k, p in enumerate(passages):
            batch_ids.append(p["id"])
            if args.no_title or "title" not in p:
                text = p["text"]
            else:
                text = p["title"] + " " + p["text"]
            if args.lowercase:
                text = text.lower()
            if args.normalize_text:
                text = src.normalize_text.normalize(text)
            batch_text.append(text)

            # Flush the batch when it is full, or on the last passage.
            if len(batch_text) == args.per_gpu_batch_size or k == len(passages) - 1:
                encoded_batch = tokenizer.batch_encode_plus(
                    batch_text,
                    return_tensors="pt",
                    max_length=args.passage_maxlength,
                    padding=True,
                    truncation=True,
                )

                encoded_batch = {key: value.cuda() for key, value in encoded_batch.items()}
                embeddings = model(**encoded_batch)

                embeddings = embeddings.cpu()
                total += len(batch_ids)
                allids.extend(batch_ids)
                allembeddings.append(embeddings)

                batch_text = []
                batch_ids = []
                if k % 100000 == 0 and k > 0:
                    print(f"Encoded passages {total}")

    allembeddings = torch.cat(allembeddings, dim=0).numpy()
    return allids, allembeddings


def main(args):
    model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path)
    print(f"Model loaded from {args.model_name_or_path}.", flush=True)
    model.eval()
    model = model.cuda()
    if not args.no_fp16:
        model = model.half()

    passages = src.data.load_passages(args.passages)

    # Keep only the slice of passages assigned to this shard; the last shard
    # also takes the remainder.
    shard_size = len(passages) // args.num_shards
    start_idx = args.shard_id * shard_size
    end_idx = start_idx + shard_size
    if args.shard_id == args.num_shards - 1:
        end_idx = len(passages)

    passages = passages[start_idx:end_idx]
    print(f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.")

    allids, allembeddings = embed_passages(args, passages, model, tokenizer)

    save_file = os.path.join(args.output_dir, args.prefix + f"_{args.shard_id:02d}")
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Saving {len(allids)} passage embeddings to {save_file}.")
    with open(save_file, mode="wb") as f:
        pickle.dump((allids, allembeddings), f)

    print(f"Total passages processed {len(allids)}. Written to {save_file}.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
    parser.add_argument("--output_dir", type=str, default="wikipedia_embeddings", help="Directory to save embeddings")
    parser.add_argument("--prefix", type=str, default="passages", help="Filename prefix for saved embeddings")
    parser.add_argument("--shard_id", type=int, default=0, help="Id of the current shard")
    parser.add_argument("--num_shards", type=int, default=1, help="Total number of shards")
    parser.add_argument(
        "--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass"
    )
    parser.add_argument("--passage_maxlength", type=int, default=512, help="Maximum number of tokens in a passage")
    parser.add_argument(
        "--model_name_or_path", type=str, help="Path to directory containing model weights and config file"
    )
    parser.add_argument("--no_fp16", action="store_true", help="Run inference in fp32 instead of fp16")
    parser.add_argument("--no_title", action="store_true", help="Do not prepend the title to the passage body")
    parser.add_argument("--lowercase", action="store_true", help="Lowercase text before encoding")
    parser.add_argument("--normalize_text", action="store_true", help="Normalize text before encoding")

    args = parser.parse_args()

    src.slurm.init_distributed_mode(args)

    main(args)
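Each shard is written as a pickled `(ids, embeddings)` tuple, so downstream code can reload them as sketched below; the glob pattern assumes the default `--output_dir` and `--prefix` values above:

```python
import glob
import pickle

import numpy as np

all_ids, shard_embeddings = [], []
for shard_file in sorted(glob.glob("wikipedia_embeddings/passages_*")):
    with open(shard_file, "rb") as f:
        ids, embeddings = pickle.load(f)  # (list of ids, numpy array of vectors)
    all_ids.extend(ids)
    shard_embeddings.append(embeddings)

all_embeddings = np.concatenate(shard_embeddings, axis=0)
print(f"Loaded {len(all_ids)} passage embeddings of dimension {all_embeddings.shape[1]}")
```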
sentence-transformers/index.rst
ADDED
@@ -0,0 +1,189 @@
SentenceTransformers Documentation
=================================================

SentenceTransformers is a Python framework for state-of-the-art sentence, text and image embeddings. The initial work is described in our paper `Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks <https://arxiv.org/abs/1908.10084>`_.

You can use this framework to compute sentence / text embeddings for more than 100 languages. These embeddings can then be compared e.g. with cosine similarity to find sentences with a similar meaning. This can be useful for `semantic textual similarity <docs/usage/semantic_textual_similarity.html>`_, `semantic search <examples/applications/semantic-search/README.html>`_, or `paraphrase mining <examples/applications/paraphrase-mining/README.html>`_.

The framework is based on `PyTorch <https://pytorch.org/>`_ and `Transformers <https://huggingface.co/transformers/>`_ and offers a large collection of `pre-trained models <docs/pretrained_models.html>`_ tuned for various tasks. Further, it is easy to `fine-tune your own models <docs/training/overview.html>`_.


Installation
=================================================

You can install it using pip:

.. code-block:: bash

   pip install -U sentence-transformers


We recommend **Python 3.6** or higher, and at least **PyTorch 1.6.0**. See `installation <docs/installation.html>`_ for further installation options, especially if you want to use a GPU.


Usage
=================================================
The usage is as simple as:

.. code-block:: python

   from sentence_transformers import SentenceTransformer
   model = SentenceTransformer('all-MiniLM-L6-v2')

   # Our sentences to encode
   sentences = ['This framework generates embeddings for each input sentence',
                'Sentences are passed as a list of strings.',
                'The quick brown fox jumps over the lazy dog.']

   # Sentences are encoded by calling model.encode()
   embeddings = model.encode(sentences)

   # Print the embeddings
   for sentence, embedding in zip(sentences, embeddings):
       print("Sentence:", sentence)
       print("Embedding:", embedding)
       print("")


Performance
=========================

Our models are evaluated extensively and achieve state-of-the-art performance on various tasks. Further, the code is tuned to provide the highest possible speed. Have a look at `Pre-Trained Models <https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models/>`_ for an overview of the available models and their performance on different tasks.


Contact
=========================

Contact person: Nils Reimers, info@nils-reimers.de

https://www.ukp.tu-darmstadt.de/

Don't hesitate to send us an e-mail or report an issue if something is broken (and it shouldn't be) or if you have further questions.

*This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.*


Citing & Authors
=========================

If you find this repository helpful, feel free to cite our publication `Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks <https://arxiv.org/abs/1908.10084>`_:

.. code-block:: bibtex

   @inproceedings{reimers-2019-sentence-bert,
       title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
       author = "Reimers, Nils and Gurevych, Iryna",
       booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
       month = "11",
       year = "2019",
       publisher = "Association for Computational Linguistics",
       url = "https://arxiv.org/abs/1908.10084",
   }


If you use one of the multilingual models, feel free to cite our publication `Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation <https://arxiv.org/abs/2004.09813>`_:

.. code-block:: bibtex

   @inproceedings{reimers-2020-multilingual-sentence-bert,
       title = "Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation",
       author = "Reimers, Nils and Gurevych, Iryna",
       booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing",
       month = "11",
       year = "2020",
       publisher = "Association for Computational Linguistics",
       url = "https://arxiv.org/abs/2004.09813",
   }


If you use the code for `data augmentation <https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/data_augmentation>`_, feel free to cite our publication `Augmented SBERT: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks <https://arxiv.org/abs/2010.08240>`_:

.. code-block:: bibtex

   @inproceedings{thakur-2020-AugSBERT,
       title = "Augmented {SBERT}: Data Augmentation Method for Improving Bi-Encoders for Pairwise Sentence Scoring Tasks",
       author = "Thakur, Nandan and Reimers, Nils and Daxenberger, Johannes and Gurevych, Iryna",
       booktitle = "Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies",
       month = jun,
       year = "2021",
       address = "Online",
       publisher = "Association for Computational Linguistics",
       url = "https://www.aclweb.org/anthology/2021.naacl-main.28",
       pages = "296--310",
   }


.. toctree::
   :maxdepth: 2
   :caption: Overview

   docs/installation
   docs/quickstart
   docs/pretrained_models
   docs/pretrained_cross-encoders
   docs/publications
   docs/hugging_face

.. toctree::
   :maxdepth: 2
   :caption: Usage

   examples/applications/computing-embeddings/README
   docs/usage/semantic_textual_similarity
   examples/applications/semantic-search/README
   examples/applications/retrieve_rerank/README
   examples/applications/clustering/README
   examples/applications/paraphrase-mining/README
   examples/applications/parallel-sentence-mining/README
   examples/applications/cross-encoder/README
   examples/applications/image-search/README

.. toctree::
   :maxdepth: 2
   :caption: Training

   docs/training/overview
   examples/training/multilingual/README
   examples/training/distillation/README
   examples/training/cross-encoder/README
   examples/training/data_augmentation/README

.. toctree::
   :maxdepth: 2
   :caption: Training Examples

   examples/training/sts/README
   examples/training/nli/README
   examples/training/paraphrases/README
   examples/training/quora_duplicate_questions/README
   examples/training/ms_marco/README

.. toctree::
   :maxdepth: 2
   :caption: Unsupervised Learning

   examples/unsupervised_learning/README
   examples/domain_adaptation/README

.. toctree::
   :maxdepth: 1
   :caption: Package Reference

   docs/package_reference/SentenceTransformer
   docs/package_reference/util
   docs/package_reference/models
   docs/package_reference/losses
   docs/package_reference/evaluation
   docs/package_reference/datasets
   docs/package_reference/cross_encoder
sentence-transformers/passage_retrieval.py
ADDED
@@ -0,0 +1,249 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

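# Retrieves the top-k passages for a set of questions from a FAISS index of
# pre-computed passage embeddings, checks whether the retrieved passages
# contain the gold answers, and writes the merged results to disk.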
import os
import argparse
import csv
import json
import logging
import pickle
import time
import glob
from pathlib import Path

import numpy as np
import torch
import transformers

import src.index
import src.contriever
import src.utils
import src.slurm
import src.data
from src.evaluation import calculate_matches
import src.normalize_text

os.environ["TOKENIZERS_PARALLELISM"] = "true"

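# Encodes the questions in GPU batches and returns a single numpy array of
# query embeddings.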
def embed_queries(args, queries, model, tokenizer):
    model.eval()
    embeddings, batch_question = [], []
    with torch.no_grad():
        for k, q in enumerate(queries):
            if args.lowercase:
                q = q.lower()
            if args.normalize_text:
                q = src.normalize_text.normalize(q)
            batch_question.append(q)

            if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1:
                encoded_batch = tokenizer.batch_encode_plus(
                    batch_question,
                    return_tensors="pt",
                    max_length=args.question_maxlength,
                    padding=True,
                    truncation=True,
                )
                encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
                output = model(**encoded_batch)
                embeddings.append(output.cpu())

                batch_question = []

    embeddings = torch.cat(embeddings, dim=0)
    print(f"Questions embeddings shape: {embeddings.size()}")

    return embeddings.numpy()

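# The two helpers below stream embeddings into the index in chunks of at most
# `indexing_batch_size` vectors, so large embedding files are indexed
# incrementally rather than all at once.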
def index_encoded_data(index, embedding_files, indexing_batch_size):
    allids = []
    allembeddings = np.array([])
    for i, file_path in enumerate(embedding_files):
        print(f"Loading file {file_path}")
        with open(file_path, "rb") as fin:
            ids, embeddings = pickle.load(fin)

        allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings
        allids.extend(ids)
        while allembeddings.shape[0] > indexing_batch_size:
            allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)

    while allembeddings.shape[0] > 0:
        allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size)

    print("Data indexing completed.")


def add_embeddings(index, embeddings, ids, indexing_batch_size):
    end_idx = min(indexing_batch_size, embeddings.shape[0])
    ids_toadd = ids[:end_idx]
    embeddings_toadd = embeddings[:end_idx]
    ids = ids[end_idx:]
    embeddings = embeddings[end_idx:]
    index.index_data(ids_toadd, embeddings_toadd)
    return embeddings, ids


def validate(data, workers_num):
    match_stats = calculate_matches(data, workers_num)
    top_k_hits = match_stats.top_k_hits

    print("Validation results: top k documents hits", top_k_hits)
    top_k_hits = [v / len(data) for v in top_k_hits]
    message = ""
    for k in [5, 10, 20, 100]:
        if k <= len(top_k_hits):
            message += f"R@{k}: {top_k_hits[k-1]} "
    print(message)
    return match_stats.questions_doc_hits


def add_passages(data, passages, top_passages_and_scores):
    # add passages to original data
    merged_data = []
    assert len(data) == len(top_passages_and_scores)
    for i, d in enumerate(data):
        results_and_scores = top_passages_and_scores[i]
        docs = [passages[doc_id] for doc_id in results_and_scores[0]]
        scores = [str(score) for score in results_and_scores[1]]
        ctxs_num = len(docs)
        d["ctxs"] = [
            {
                "id": results_and_scores[0][c],
                "title": docs[c]["title"],
                "text": docs[c]["text"],
                "score": scores[c],
            }
            for c in range(ctxs_num)
        ]


def add_hasanswer(data, hasanswer):
    # add hasanswer to data
    for i, ex in enumerate(data):
        for k, d in enumerate(ex["ctxs"]):
            d["hasanswer"] = hasanswer[i][k]


def load_data(data_path):
    if data_path.endswith(".json"):
        with open(data_path, "r") as fin:
            data = json.load(fin)
    elif data_path.endswith(".jsonl"):
        data = []
        with open(data_path, "r") as fin:
            for k, example in enumerate(fin):
                example = json.loads(example)
                data.append(example)
    return data


def main(args):

    print(f"Loading model from: {args.model_name_or_path}")
    model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path)
    model.eval()
    model = model.cuda()
    if not args.no_fp16:
        model = model.half()

    index = src.index.Indexer(args.projection_size, args.n_subquantizers, args.n_bits)

    # index all passages
    input_paths = glob.glob(args.passages_embeddings)
    input_paths = sorted(input_paths)
    embeddings_dir = os.path.dirname(input_paths[0])
    index_path = os.path.join(embeddings_dir, "index.faiss")
    if args.save_or_load_index and os.path.exists(index_path):
        index.deserialize_from(embeddings_dir)
    else:
        print(f"Indexing passages from files {input_paths}")
        start_time_indexing = time.time()
        index_encoded_data(index, input_paths, args.indexing_batch_size)
        print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
        if args.save_or_load_index:
            index.serialize(embeddings_dir)

    # load passages
    passages = src.data.load_passages(args.passages)
    passage_id_map = {x["id"]: x for x in passages}

    data_paths = glob.glob(args.data)
    alldata = []
    for path in data_paths:
        data = load_data(path)
        output_path = os.path.join(args.output_dir, os.path.basename(path))

        queries = [ex["question"] for ex in data]
        questions_embedding = embed_queries(args, queries, model, tokenizer)

        # get top k results
        start_time_retrieval = time.time()
        top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)
        print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")

        add_passages(data, passage_id_map, top_ids_and_scores)
        hasanswer = validate(data, args.validation_workers)
        add_hasanswer(data, hasanswer)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, "w") as fout:
            for ex in data:
                json.dump(ex, fout, ensure_ascii=False)
                fout.write("\n")
        print(f"Saved results to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--data",
        required=True,
        type=str,
        default=None,
        help=".json file containing questions and answers, similar format to reader data",
    )
    parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
    parser.add_argument("--passages_embeddings", type=str, default=None, help="Glob path to encoded passages")
    parser.add_argument(
        "--output_dir", type=str, default=None, help="Results are written to output_dir with the data file's name"
    )
    parser.add_argument("--n_docs", type=int, default=100, help="Number of documents to retrieve per question")
    parser.add_argument(
        "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
    )
    parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding")
    parser.add_argument(
        "--save_or_load_index", action="store_true", help="If enabled, save the index and load it if it exists"
    )
    parser.add_argument(
        "--model_name_or_path", type=str, help="Path to directory containing model weights and config file"
    )
    parser.add_argument("--no_fp16", action="store_true", help="Run inference in fp32")
    parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question")
    parser.add_argument(
        "--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
    )
    parser.add_argument("--projection_size", type=int, default=768)
    parser.add_argument(
        "--n_subquantizers",
        type=int,
        default=0,
        help="Number of subquantizers used for vector quantization; if 0, a flat index is used",
    )
    parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer")
    parser.add_argument("--lang", nargs="+")
    parser.add_argument("--dataset", type=str, default="none")
    parser.add_argument("--lowercase", action="store_true", help="Lowercase text before encoding")
    parser.add_argument("--normalize_text", action="store_true", help="Normalize text")

    args = parser.parse_args()
    src.slurm.init_distributed_mode(args)
    main(args)
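# Example invocation (illustrative; all paths and the model name are
# hypothetical, the flags are the ones defined above):
#   python passage_retrieval.py \
#       --model_name_or_path facebook/contriever \
#       --passages psgs_w100.tsv \
#       --passages_embeddings "embeddings/passages_*" \
#       --data nq/test.json \
#       --output_dir retrieval_results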
sentence-transformers/preprocess.py
ADDED
@@ -0,0 +1,68 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

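# Tokenizes a raw text file line by line and saves the concatenated token
# ids as a single torch tensor, to be used as a pre-tokenized training corpus.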
import os
import argparse
import torch

import transformers
from src.normalize_text import normalize


def save(tensor, split_path):
    if not os.path.exists(os.path.dirname(split_path)):
        os.makedirs(os.path.dirname(split_path))
    with open(split_path, 'wb') as fout:
        torch.save(tensor, fout)


def apply_tokenizer(path, tokenizer, normalize_text=False):
    alltokens = []
    lines = []
    with open(path, "r", encoding="utf-8") as fin:
        for k, line in enumerate(fin):
            if normalize_text:
                line = normalize(line)

            lines.append(line)
            # tokenize in chunks of one million lines to bound memory usage
            if len(lines) > 1000000:
                tokens = tokenizer.batch_encode_plus(lines, add_special_tokens=False)['input_ids']
                tokens = [torch.tensor(x, dtype=torch.int) for x in tokens]
                alltokens.extend(tokens)
                lines = []

    tokens = tokenizer.batch_encode_plus(lines, add_special_tokens=False)['input_ids']
    tokens = [torch.tensor(x, dtype=torch.int) for x in tokens]
    alltokens.extend(tokens)

    alltokens = torch.cat(alltokens)
    return alltokens


def tokenize_file(args):
    filename = os.path.basename(args.datapath)
    savepath = os.path.join(args.outdir, f"{filename}.pkl")
    if os.path.exists(savepath):
        if args.overwrite:
            print(f"File {savepath} already exists, overwriting")
        else:
            print(f"File {savepath} already exists, exiting")
            return
    # prefer the local cache; fall back to downloading the tokenizer
    try:
        tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer, local_files_only=True)
    except Exception:
        tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer, local_files_only=False)
    print(f"Encoding {args.datapath}...")
    tokens = apply_tokenizer(args.datapath, tokenizer, normalize_text=args.normalize_text)

    print(f"Saving at {savepath}...")
    save(tokens, savepath)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--datapath", type=str)
    parser.add_argument("--outdir", type=str)
    parser.add_argument("--tokenizer", type=str)
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument("--normalize_text", action="store_true")

    args, _ = parser.parse_known_args()
    tokenize_file(args)
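# Example invocation (illustrative; the paths and tokenizer name are
# hypothetical):
#   python preprocess.py --datapath corpus.txt --outdir tokenized/ \
#       --tokenizer bert-base-uncased --normalize_text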
sentence-transformers/requirements.txt
ADDED
@@ -0,0 +1,11 @@
transformers>=4.6.0,<5.0.0
tokenizers>=0.10.3
tqdm
torch>=1.6.0
torchvision
numpy
scikit-learn
scipy
nltk
sentencepiece
huggingface-hub
sentence-transformers/setup.cfg
ADDED
@@ -0,0 +1,2 @@
[metadata]
description-file = README.md
sentence-transformers/setup.py
ADDED
@@ -0,0 +1,41 @@
from setuptools import setup, find_packages

with open("README.md", mode="r", encoding="utf-8") as readme_file:
    readme = readme_file.read()


setup(
    name="sentence-transformers",
    version="2.2.2",
    author="Nils Reimers",
    author_email="info@nils-reimers.de",
    description="Multilingual text embeddings",
    long_description=readme,
    long_description_content_type="text/markdown",
    license="Apache License 2.0",
    url="https://www.SBERT.net",
    download_url="https://github.com/UKPLab/sentence-transformers/",
    packages=find_packages(),
    python_requires=">=3.6.0",
    install_requires=[
        'transformers>=4.6.0,<5.0.0',
        'tqdm',
        'torch>=1.6.0',
        'torchvision',
        'numpy',
        'scikit-learn',
        'scipy',
        'nltk',
        'sentencepiece',
        'huggingface-hub>=0.4.0',
    ],
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3.6",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    keywords="Transformer Networks BERT XLNet sentence embedding PyTorch NLP deep learning",
)
sentence-transformers/train.py
ADDED
@@ -0,0 +1,195 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

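# Contrastive training loop: loads a tokenized corpus, trains a MoCo or
# in-batch-negatives retriever, periodically evaluates on BEIR datasets, and
# checkpoints the latest state.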
import os
import time
import sys
import torch
import logging
import json
import numpy as np
import random
import pickle

import torch.distributed as dist
from torch.utils.data import DataLoader, RandomSampler

from src.options import Options
from src import data, beir_utils, slurm, dist_utils, utils
from src import moco, inbatch


logger = logging.getLogger(__name__)

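# Runs the training loop until opt.total_steps optimizer steps have been taken,
# logging running statistics and evaluating/saving at fixed intervals.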
def train(opt, model, optimizer, scheduler, step):

    run_stats = utils.WeightedAvgStats()

    tb_logger = utils.init_tb_logger(opt.output_dir)

    logger.info("Data loading")
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        tokenizer = model.module.tokenizer
    else:
        tokenizer = model.tokenizer
    collator = data.Collator(opt=opt)
    train_dataset = data.load_data(opt, tokenizer)
    logger.warning(f"Data loading finished for rank {dist_utils.get_rank()}")

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=opt.per_gpu_batch_size,
        drop_last=True,
        num_workers=opt.num_workers,
        collate_fn=collator,
    )

    epoch = 1

    model.train()
    while step < opt.total_steps:
        train_dataset.generate_offset()

        logger.info(f"Start epoch {epoch}")
        for i, batch in enumerate(train_dataloader):
            step += 1

            batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
            train_loss, iter_stats = model(**batch, stats_prefix="train")

            train_loss.backward()
            optimizer.step()

            scheduler.step()
            model.zero_grad()

            run_stats.update(iter_stats)

            if step % opt.log_freq == 0:
                log = f"{step} / {opt.total_steps}"
                for k, v in sorted(run_stats.average_stats.items()):
                    log += f" | {k}: {v:.3f}"
                    if tb_logger:
                        tb_logger.add_scalar(k, v, step)
                log += f" | lr: {scheduler.get_last_lr()[0]:0.3g}"
                log += f" | Memory: {torch.cuda.max_memory_allocated() // 1e9} GB"

                logger.info(log)
                run_stats.reset()

            if step % opt.eval_freq == 0:
                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                    encoder = model.module.get_encoder()
                else:
                    encoder = model.get_encoder()
                eval_model(
                    opt, query_encoder=encoder, doc_encoder=encoder, tokenizer=tokenizer, tb_logger=tb_logger, step=step
                )

                if dist_utils.is_main():
                    utils.save(model, optimizer, scheduler, step, opt, opt.output_dir, "lastlog")

                model.train()

            if dist_utils.is_main() and step % opt.save_freq == 0:
                utils.save(model, optimizer, scheduler, step, opt, opt.output_dir, f"step-{step}")

            if step > opt.total_steps:
                break
        epoch += 1


def eval_model(opt, query_encoder, doc_encoder, tokenizer, tb_logger, step):
    for datasetname in opt.eval_datasets:
        metrics = beir_utils.evaluate_model(
            query_encoder,
            doc_encoder,
            tokenizer,
            dataset=datasetname,
            batch_size=opt.per_gpu_eval_batch_size,
            norm_doc=opt.norm_doc,
            norm_query=opt.norm_query,
            beir_dir=opt.eval_datasets_dir,
            score_function=opt.score_function,
            lower_case=opt.lower_case,
            normalize_text=opt.eval_normalize_text,
        )

        message = []
        if dist_utils.is_main():
            for metric in ["NDCG@10", "Recall@10", "Recall@100"]:
                message.append(f"{datasetname}/{metric}: {metrics[metric]:.2f}")
                if tb_logger is not None:
                    tb_logger.add_scalar(f"{datasetname}/{metric}", metrics[metric], step)
            logger.info(" | ".join(message))

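# Entry point: a run either starts from scratch, resumes from the latest
# checkpoint in output_dir, or warm-starts from an explicit --model_path.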
if __name__ == "__main__":
    logger.info("Start")

    options = Options()
    opt = options.parse()

    torch.manual_seed(opt.seed)
    slurm.init_distributed_mode(opt)
    slurm.init_signal_handler()

    directory_exists = os.path.isdir(opt.output_dir)
    if dist.is_initialized():
        dist.barrier()
    os.makedirs(opt.output_dir, exist_ok=True)
    if not directory_exists and dist_utils.is_main():
        options.print_options(opt)
    if dist.is_initialized():
        dist.barrier()
    utils.init_logger(opt)

    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    if opt.contrastive_mode == "moco":
        model_class = moco.MoCo
    elif opt.contrastive_mode == "inbatch":
        model_class = inbatch.InBatch
    else:
        raise ValueError(f"contrastive mode: {opt.contrastive_mode} not recognised")

    if not directory_exists and opt.model_path == "none":
        model = model_class(opt)
        model = model.cuda()
        optimizer, scheduler = utils.set_optim(opt, model)
        step = 0
    elif directory_exists:
        model_path = os.path.join(opt.output_dir, "checkpoint", "latest")
        model, optimizer, scheduler, opt_checkpoint, step = utils.load(
            model_class,
            model_path,
            opt,
            reset_params=False,
        )
        logger.info(f"Model loaded from {opt.output_dir}")
    else:
        model, optimizer, scheduler, opt_checkpoint, step = utils.load(
            model_class,
            opt.model_path,
            opt,
            reset_params=False if opt.continue_training else True,
        )
        if not opt.continue_training:
            step = 0
        logger.info(f"Model loaded from {opt.model_path}")

    logger.info(utils.get_parameters(model))

    if dist.is_initialized():
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            find_unused_parameters=False,
        )
        dist.barrier()

    logger.info("Start training")
    train(opt, model, optimizer, scheduler, step)