Sentence Similarity
sentence-transformers
PyTorch
English
bert
feature-extraction
mteb
custom_code
Eval Results
text-embeddings-inference
6 papers
dylanAtHum commited on
Commit
64ae4c7
1 Parent(s): 55fdd33

Initial Commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. 1_Pooling/config.json +4 -0
  2. Data_Records.ipynb +92 -0
  3. Dataloading.ipynb +675 -0
  4. README.md +1937 -0
  5. Replication.txt +40 -0
  6. Training.py +465 -0
  7. bert_layers.py +1072 -0
  8. bert_padding.py +159 -0
  9. config.json +33 -0
  10. config_sentence_transformers.json +7 -0
  11. configuration_bert.py +25 -0
  12. data_records.json +1 -0
  13. flash_attn_triton.py +1112 -0
  14. modules.json +20 -0
  15. mteb_results/AmazonCounterfactualClassification.json +29 -0
  16. mteb_results/AmazonPolarityClassification.json +15 -0
  17. mteb_results/AmazonReviewsClassification.json +25 -0
  18. mteb_results/ArguAna.json +38 -0
  19. mteb_results/ArxivClusteringP2P.json +10 -0
  20. mteb_results/ArxivClusteringS2S.json +10 -0
  21. mteb_results/AskUbuntuDupQuestions.json +10 -0
  22. mteb_results/BIOSSES.json +20 -0
  23. mteb_results/Banking77Classification.json +13 -0
  24. mteb_results/BiorxivClusteringP2P.json +10 -0
  25. mteb_results/BiorxivClusteringS2S.json +10 -0
  26. mteb_results/CQADupstackEnglishRetrieval.json +38 -0
  27. mteb_results/ClimateFEVER.json +38 -0
  28. mteb_results/DBPedia.json +38 -0
  29. mteb_results/EmotionClassification.json +21 -0
  30. mteb_results/FEVER.json +38 -0
  31. mteb_results/FiQA2018.json +38 -0
  32. mteb_results/HotpotQA.json +38 -0
  33. mteb_results/ImdbClassification.json +15 -0
  34. mteb_results/MSMARCO.json +38 -0
  35. mteb_results/MTOPDomainClassification.json +25 -0
  36. mteb_results/MTOPIntentClassification.json +25 -0
  37. mteb_results/MassiveIntentClassification.json +25 -0
  38. mteb_results/MassiveScenarioClassification.json +25 -0
  39. mteb_results/MedrxivClusteringP2P.json +10 -0
  40. mteb_results/MedrxivClusteringS2S.json +10 -0
  41. mteb_results/MindSmallReranking.json +10 -0
  42. mteb_results/NFCorpus.json +38 -0
  43. mteb_results/NQ.json +38 -0
  44. mteb_results/QuoraRetrieval.json +38 -0
  45. mteb_results/RedditClustering.json +10 -0
  46. mteb_results/RedditClusteringP2P.json +10 -0
  47. mteb_results/SCIDOCS.json +38 -0
  48. mteb_results/SICK-R.json +20 -0
  49. mteb_results/STS12.json +20 -0
  50. mteb_results/STS13.json +20 -0
1_Pooling/config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "word_embedding_dimension": 768,
3
+ "pooling_mode_mean_tokens": true
4
+ }
Data_Records.ipynb ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "e66bbb77-71f5-4d80-b766-f67144ea7a93",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Data Records\n",
9
+ "\n",
10
+ "## This notebook generates the data_records.json file where each entry in the resulting dictionary follows the form {filename: num_records} for every dataset we will use during training"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 39,
16
+ "id": "74ad6613-44ff-435e-8550-df993e915677",
17
+ "metadata": {
18
+ "tags": []
19
+ },
20
+ "outputs": [],
21
+ "source": [
22
+ "# import relevant libraries\n",
23
+ "import os\n",
24
+ "import boto3\n",
25
+ "import json\n",
26
+ "from smart_open import open"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "id": "e2d53761-da0e-44f4-8a3e-1285bf810b03",
33
+ "metadata": {
34
+ "tags": []
35
+ },
36
+ "outputs": [],
37
+ "source": [
38
+ "s3 = boto3.resource('s3')\n",
39
+ "my_bucket = s3.Bucket('lodestone-rnd')\n",
40
+ "\n",
41
+ "# collect all filenames from the data/ directory of the lodestone-rnd S3 bucket\n",
42
+ "files = [\"\"]*((621+12+9+36)+1)\n",
43
+ "for i, object_summary in enumerate(my_bucket.objects.filter(Prefix=\"data/\")):\n",
44
+ " files[i] = object_summary.key[5:]\n",
45
+ "files = files[1:]\n",
46
+ "files = [file for file in files if file != 'cnn_dailymail_splitted.json.gz']\n",
47
+ "\n",
48
+ "s3_client = boto3.client(\"s3\")\n",
49
+ "\n",
50
+ "# for each training dataset, store the number of records in a dictionary with the following form {filename: num_records}\n",
51
+ "data_lengths = {}\n",
52
+ "for file in files:\n",
53
+ " source_uri = f's3://lodestone-rnd/data/{file}'\n",
54
+ " # S2ORC_citations_abstracts.json.gz and amazon-qa.json.gz must be handled differently since each line in their training\n",
55
+ " # data is split into multiple records due to the fact that each query has multiple positive pair responses\n",
56
+ " if file in ['S2ORC_citations_abstracts.json.gz','amazon-qa.json.gz']:\n",
57
+ " length = 0\n",
58
+ " for json_line in open(source_uri, transport_params={\"client\": s3_client}):\n",
59
+ " data = json.loads(json_line.strip())\n",
60
+ " length += len(data['pos'])\n",
61
+ " else:\n",
62
+ " length = int(os.popen(f'aws s3 cp {source_uri} - | zcat | wc -l').read().rstrip())\n",
63
+ " data_lengths[f'{file}'] = length\n",
64
+ " \n",
65
+ "# write the resulting dictionary to a .json file for future use during training\n",
66
+ "with open('data_records.json', 'w') as fileout:\n",
67
+ " json.dump(data_lengths, fileout)"
68
+ ]
69
+ }
70
+ ],
71
+ "metadata": {
72
+ "kernelspec": {
73
+ "display_name": "conda_pytorch_p310",
74
+ "language": "python",
75
+ "name": "conda_pytorch_p310"
76
+ },
77
+ "language_info": {
78
+ "codemirror_mode": {
79
+ "name": "ipython",
80
+ "version": 3
81
+ },
82
+ "file_extension": ".py",
83
+ "mimetype": "text/x-python",
84
+ "name": "python",
85
+ "nbconvert_exporter": "python",
86
+ "pygments_lexer": "ipython3",
87
+ "version": "3.10.10"
88
+ }
89
+ },
90
+ "nbformat": 4,
91
+ "nbformat_minor": 5
92
+ }
Dataloading.ipynb ADDED
@@ -0,0 +1,675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "3f9ce240-fd1a-4550-83c7-8cf9658b1d3a",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Dataloading (1B+ Training Pairs)\n",
9
+ "\n",
10
+ "## This notebook collects and uploads all 50 relevant sentence embedding datasets to S3 as .json.gz files where each line contains one training record"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 62,
16
+ "id": "d7af8d0e-b3ed-4007-abdd-5952d775e119",
17
+ "metadata": {
18
+ "tags": []
19
+ },
20
+ "outputs": [],
21
+ "source": [
22
+ "import os\n",
23
+ "import pandas as pd\n",
24
+ "\n",
25
+ "os.chdir('/home/ec2-user/SageMaker')"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": 63,
31
+ "id": "686699fb-d2c9-4653-afca-580aef343451",
32
+ "metadata": {
33
+ "collapsed": true,
34
+ "jupyter": {
35
+ "outputs_hidden": true
36
+ },
37
+ "tags": []
38
+ },
39
+ "outputs": [
40
+ {
41
+ "name": "stdout",
42
+ "output_type": "stream",
43
+ "text": [
44
+ "Loaded plugins: dkms-build-requires, extras_suggestions, langpacks, priorities,\n",
45
+ " : update-motd, versionlock\n",
46
+ "Cleaning repos: amzn2-core amzn2extra-docker amzn2extra-epel\n",
47
+ " : amzn2extra-kernel-5.10 amzn2extra-python3.8 centos-extras\n",
48
+ " : copr:copr.fedorainfracloud.org:vbatts:shadow-utils-newxidmap\n",
49
+ " : docker-ce-stable libnvidia-container neuron\n",
50
+ "21 metadata files removed\n",
51
+ "15 sqlite files removed\n",
52
+ "0 metadata files removed\n",
53
+ "Loaded plugins: dkms-build-requires, extras_suggestions, langpacks, priorities,\n",
54
+ " : update-motd, versionlock\n"
55
+ ]
56
+ },
57
+ {
58
+ "name": "stderr",
59
+ "output_type": "stream",
60
+ "text": [
61
+ "https://download.docker.com/linux/centos/2/x86_64/stable/repodata/repomd.xml: [Errno 14] HTTPS Error 404 - Not Found\n",
62
+ "Trying other mirror.\n"
63
+ ]
64
+ },
65
+ {
66
+ "name": "stdout",
67
+ "output_type": "stream",
68
+ "text": [
69
+ "62 packages excluded due to repository priority protections\n",
70
+ "Resolving Dependencies\n",
71
+ "--> Running transaction check\n",
72
+ "---> Package epel-release.noarch 0:7-11 will be installed\n",
73
+ "--> Finished Dependency Resolution\n",
74
+ "\n",
75
+ "Dependencies Resolved\n",
76
+ "\n",
77
+ "================================================================================\n",
78
+ " Package Arch Version Repository Size\n",
79
+ "================================================================================\n",
80
+ "Installing:\n",
81
+ " epel-release noarch 7-11 amzn2extra-epel 15 k\n",
82
+ "\n",
83
+ "Transaction Summary\n",
84
+ "================================================================================\n",
85
+ "Install 1 Package\n",
86
+ "\n",
87
+ "Total download size: 15 k\n",
88
+ "Installed size: 24 k\n",
89
+ "Downloading packages:\n",
90
+ "Running transaction check\n",
91
+ "Running transaction test\n",
92
+ "Transaction test succeeded\n",
93
+ "Running transaction\n"
94
+ ]
95
+ },
96
+ {
97
+ "name": "stderr",
98
+ "output_type": "stream",
99
+ "text": [
100
+ "Warning: RPMDB altered outside of yum.\n"
101
+ ]
102
+ },
103
+ {
104
+ "name": "stdout",
105
+ "output_type": "stream",
106
+ "text": [
107
+ " Installing : epel-release-7-11.noarch 1/1 \n",
108
+ " Verifying : epel-release-7-11.noarch 1/1 \n",
109
+ "\n",
110
+ "Installed:\n",
111
+ " epel-release.noarch 0:7-11 \n",
112
+ "\n",
113
+ "Complete!\n",
114
+ "Installing epel-release\n",
115
+ " 0 ansible2 available \\\n",
116
+ " [ =2.4.2 =2.4.6 =2.8 =stable ]\n",
117
+ " 2 httpd_modules available [ =1.0 =stable ]\n",
118
+ " 3 memcached1.5 available \\\n",
119
+ " [ =1.5.1 =1.5.16 =1.5.17 ]\n",
120
+ " 6 postgresql10 available [ =10 =stable ]\n",
121
+ " 9 R3.4 available [ =3.4.3 =stable ]\n",
122
+ " 10 rust1 available \\\n",
123
+ " [ =1.22.1 =1.26.0 =1.26.1 =1.27.2 =1.31.0 =1.38.0\n",
124
+ " =stable ]\n",
125
+ " 18 libreoffice available \\\n",
126
+ " [ =5.0.6.2_15 =5.3.6.1 =stable ]\n",
127
+ " 19 gimp available [ =2.8.22 ]\n",
128
+ " 20 docker=latest enabled \\\n",
129
+ " [ =17.12.1 =18.03.1 =18.06.1 =18.09.9 =stable ]\n",
130
+ " 21 mate-desktop1.x available \\\n",
131
+ " [ =1.19.0 =1.20.0 =stable ]\n",
132
+ " 22 GraphicsMagick1.3 available \\\n",
133
+ " [ =1.3.29 =1.3.32 =1.3.34 =stable ]\n",
134
+ " 23 tomcat8.5 available \\\n",
135
+ " [ =8.5.31 =8.5.32 =8.5.38 =8.5.40 =8.5.42 =8.5.50\n",
136
+ " =stable ]\n",
137
+ " 24 epel=latest enabled [ =7.11 =stable ]\n",
138
+ " 25 testing available [ =1.0 =stable ]\n",
139
+ " 26 ecs available [ =stable ]\n",
140
+ " 27 corretto8 available \\\n",
141
+ " [ =1.8.0_192 =1.8.0_202 =1.8.0_212 =1.8.0_222 =1.8.0_232\n",
142
+ " =1.8.0_242 =stable ]\n",
143
+ " 29 golang1.11 available \\\n",
144
+ " [ =1.11.3 =1.11.11 =1.11.13 =stable ]\n",
145
+ " 30 squid4 available [ =4 =stable ]\n",
146
+ " 32 lustre2.10 available \\\n",
147
+ " [ =2.10.5 =2.10.8 =stable ]\n",
148
+ " 33 java-openjdk11 available [ =11 =stable ]\n",
149
+ " 34 lynis available [ =stable ]\n",
150
+ " 36 BCC available [ =0.x =stable ]\n",
151
+ " 37 mono available [ =5.x =stable ]\n",
152
+ " 38 nginx1 available [ =stable ]\n",
153
+ " 40 mock available [ =stable ]\n",
154
+ " 41 postgresql11 available [ =11 =stable ]\n",
155
+ " 43 livepatch available [ =stable ]\n",
156
+ " 44 python3.8=latest enabled [ =stable ]\n",
157
+ " 45 haproxy2 available [ =stable ]\n",
158
+ " 46 collectd available [ =stable ]\n",
159
+ " 47 aws-nitro-enclaves-cli available [ =stable ]\n",
160
+ " 48 R4 available [ =stable ]\n",
161
+ " _ kernel-5.4 available [ =stable ]\n",
162
+ " 50 selinux-ng available [ =stable ]\n",
163
+ " 51 php8.0 available [ =stable ]\n",
164
+ " 52 tomcat9 available [ =stable ]\n",
165
+ " 53 unbound1.13 available [ =stable ]\n",
166
+ " 54 mariadb10.5 available [ =stable ]\n",
167
+ " 55 kernel-5.10=latest enabled [ =stable ]\n",
168
+ " 56 redis6 available [ =stable ]\n",
169
+ " 57 ruby3.0 available [ =stable ]\n",
170
+ " 58 postgresql12 available [ =stable ]\n",
171
+ " 59 postgresql13 available [ =stable ]\n",
172
+ " 60 mock2 available [ =stable ]\n",
173
+ " 61 dnsmasq2.85 available [ =stable ]\n",
174
+ " 62 kernel-5.15 available [ =stable ]\n",
175
+ " 63 postgresql14 available [ =stable ]\n",
176
+ " 64 firefox available [ =stable ]\n",
177
+ " 65 lustre available [ =stable ]\n",
178
+ " 66 php8.1 available [ =stable ]\n",
179
+ " 67 awscli1 available [ =stable ]\n",
180
+ " 68 php8.2 available [ =stable ]\n",
181
+ " 69 dnsmasq available [ =stable ]\n",
182
+ " 70 unbound1.17 available [ =stable ]\n",
183
+ " 71 golang1.19 available [ =stable ]\n",
184
+ " 72 collectd-python3 available [ =stable ]\n",
185
+ "Loaded plugins: dkms-build-requires, extras_suggestions, langpacks, priorities,\n",
186
+ " : update-motd, versionlock\n",
187
+ "================================== repo: epel ==================================\n",
188
+ "[epel]\n",
189
+ "async = True\n",
190
+ "bandwidth = 0\n",
191
+ "base_persistdir = /var/lib/yum/repos/x86_64/2\n",
192
+ "baseurl = \n",
193
+ "cache = 0\n",
194
+ "cachedir = /var/cache/yum/x86_64/2/epel\n",
195
+ "check_config_file_age = True\n",
196
+ "compare_providers_priority = 80\n",
197
+ "cost = 1000\n",
198
+ "deltarpm_metadata_percentage = 100\n",
199
+ "deltarpm_percentage = \n",
200
+ "enabled = True\n",
201
+ "enablegroups = True\n",
202
+ "exclude = \n",
203
+ "failovermethod = priority\n",
204
+ "ftp_disable_epsv = False\n",
205
+ "gpgcadir = /var/lib/yum/repos/x86_64/2/epel/gpgcadir\n",
206
+ "gpgcakey = \n",
207
+ "gpgcheck = True\n",
208
+ "gpgdir = /var/lib/yum/repos/x86_64/2/epel/gpgdir\n",
209
+ "gpgkey = file:///etc/pki/rpm-gpg/RPM-GPG-KEY-EPEL-7\n",
210
+ "hdrdir = /var/cache/yum/x86_64/2/epel/headers\n",
211
+ "http_caching = all\n",
212
+ "includepkgs = \n",
213
+ "ip_resolve = \n",
214
+ "keepalive = True\n",
215
+ "keepcache = False\n",
216
+ "mddownloadpolicy = sqlite\n",
217
+ "mdpolicy = group:small\n",
218
+ "mediaid = \n",
219
+ "metadata_expire = 21600\n",
220
+ "metadata_expire_filter = read-only:present\n",
221
+ "metalink = https://mirrors.fedoraproject.org/metalink?repo=epel-7&arch=x86_64\n",
222
+ "minrate = 0\n",
223
+ "mirrorlist = \n",
224
+ "mirrorlist_expire = 86400\n",
225
+ "name = Extra Packages for Enterprise Linux 7 - x86_64\n",
226
+ "old_base_cache_dir = \n",
227
+ "password = \n",
228
+ "persistdir = /var/lib/yum/repos/x86_64/2/epel\n",
229
+ "pkgdir = /var/cache/yum/x86_64/2/epel/packages\n",
230
+ "priority = 99\n",
231
+ "proxy = False\n",
232
+ "proxy_dict = \n",
233
+ "proxy_password = \n",
234
+ "proxy_username = \n",
235
+ "repo_gpgcheck = False\n",
236
+ "report_instanceid = False\n",
237
+ "retries = 7\n",
238
+ "skip_if_unavailable = False\n",
239
+ "ssl_check_cert_permissions = True\n",
240
+ "sslcacert = \n",
241
+ "sslclientcert = \n",
242
+ "sslclientkey = \n",
243
+ "sslverify = True\n",
244
+ "throttle = 0\n",
245
+ "timeout = 5.0\n",
246
+ "ui_id = epel/x86_64\n",
247
+ "ui_repoid_vars = releasever,\n",
248
+ " basearch\n",
249
+ "username = \n",
250
+ "\n",
251
+ "Loaded plugins: dkms-build-requires, extras_suggestions, langpacks, priorities,\n",
252
+ " : update-motd, versionlock\n"
253
+ ]
254
+ },
255
+ {
256
+ "name": "stderr",
257
+ "output_type": "stream",
258
+ "text": [
259
+ "https://download.docker.com/linux/centos/2/x86_64/stable/repodata/repomd.xml: [Errno 14] HTTPS Error 404 - Not Found\n",
260
+ "Trying other mirror.\n",
261
+ "http://mirror.es.its.nyu.edu/epel/7/x86_64/repodata/repomd.xml: [Errno 12] Timeout on http://mirror.es.its.nyu.edu/epel/7/x86_64/repodata/repomd.xml: (28, 'Failed to connect to mirror.es.its.nyu.edu port 80 after 5001 ms: Timeout was reached')\n",
262
+ "Trying other mirror.\n"
263
+ ]
264
+ },
265
+ {
266
+ "name": "stdout",
267
+ "output_type": "stream",
268
+ "text": [
269
+ "286 packages excluded due to repository priority protections\n",
270
+ "Resolving Dependencies\n",
271
+ "--> Running transaction check\n",
272
+ "---> Package git-lfs.x86_64 0:2.10.0-2.el7 will be installed\n",
273
+ "--> Finished Dependency Resolution\n",
274
+ "\n",
275
+ "Dependencies Resolved\n",
276
+ "\n",
277
+ "================================================================================\n",
278
+ " Package Arch Version Repository Size\n",
279
+ "================================================================================\n",
280
+ "Installing:\n",
281
+ " git-lfs x86_64 2.10.0-2.el7 epel 3.7 M\n",
282
+ "\n",
283
+ "Transaction Summary\n",
284
+ "================================================================================\n",
285
+ "Install 1 Package\n",
286
+ "\n",
287
+ "Total download size: 3.7 M\n",
288
+ "Installed size: 13 M\n",
289
+ "Downloading packages:\n"
290
+ ]
291
+ },
292
+ {
293
+ "name": "stderr",
294
+ "output_type": "stream",
295
+ "text": [
296
+ "warning: /var/cache/yum/x86_64/2/epel/packages/git-lfs-2.10.0-2.el7.x86_64.rpm: Header V4 RSA/SHA256 Signature, key ID 352c64e5: NOKEY\n",
297
+ "Importing GPG key 0x352C64E5:\n",
298
+ " Userid : \"Fedora EPEL (7) <epel@fedoraproject.org>\"\n",
299
+ " Fingerprint: 91e9 7d7c 4a5e 96f1 7f3e 888f 6a2f aea2 352c 64e5\n",
300
+ " Package : epel-release-7-11.noarch (@amzn2extra-epel)\n",
301
+ " From : /etc/pki/rpm-gpg/RPM-GPG-KEY-EPEL-7\n"
302
+ ]
303
+ },
304
+ {
305
+ "name": "stdout",
306
+ "output_type": "stream",
307
+ "text": [
308
+ "Public key for git-lfs-2.10.0-2.el7.x86_64.rpm is not installed\n",
309
+ "Retrieving key from file:///etc/pki/rpm-gpg/RPM-GPG-KEY-EPEL-7\n",
310
+ "Running transaction check\n",
311
+ "Running transaction test\n",
312
+ "Transaction test succeeded\n",
313
+ "Running transaction\n",
314
+ " Installing : git-lfs-2.10.0-2.el7.x86_64 1/1 \n",
315
+ " Verifying : git-lfs-2.10.0-2.el7.x86_64 1/1 \n",
316
+ "\n",
317
+ "Installed:\n",
318
+ " git-lfs.x86_64 0:2.10.0-2.el7 \n",
319
+ "\n",
320
+ "Complete!\n",
321
+ "Git LFS initialized.\n"
322
+ ]
323
+ },
324
+ {
325
+ "data": {
326
+ "text/plain": [
327
+ "0"
328
+ ]
329
+ },
330
+ "execution_count": 63,
331
+ "metadata": {},
332
+ "output_type": "execute_result"
333
+ }
334
+ ],
335
+ "source": [
336
+ "# install git-lfs\n",
337
+ "os.system('sudo amazon-linux-extras install epel -y')\n",
338
+ "os.system('sudo yum-config-manager --enable epel')\n",
339
+ "os.system('sudo yum install git-lfs -y')\n",
340
+ "os.system('git lfs install')"
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "code",
345
+ "execution_count": 3,
346
+ "id": "6595a6c5-9fae-4bd4-b096-53475ec98294",
347
+ "metadata": {
348
+ "tags": []
349
+ },
350
+ "outputs": [
351
+ {
352
+ "name": "stderr",
353
+ "output_type": "stream",
354
+ "text": [
355
+ "Cloning into 'stackexchange_title_body_jsonl'...\n",
356
+ "Cloning into 'stackexchange_titlebody_best_voted_answer_jsonl'...\n",
357
+ "Cloning into 'stackexchange_title_best_voted_answer_jsonl'...\n",
358
+ "Cloning into 'stackexchange_titlebody_best_and_down_voted_answer_jsonl'...\n",
359
+ "Cloning into 'reddit-title-body'...\n",
360
+ "Cloning into '1B_sentence_embeddings'...\n"
361
+ ]
362
+ },
363
+ {
364
+ "data": {
365
+ "text/plain": [
366
+ "0"
367
+ ]
368
+ },
369
+ "execution_count": 3,
370
+ "metadata": {},
371
+ "output_type": "execute_result"
372
+ }
373
+ ],
374
+ "source": [
375
+ "# clone relevant datasets' github repositories\n",
376
+ "stacks = ['stackexchange_title_body_jsonl', #25.3M \n",
377
+ " 'stackexchange_titlebody_best_voted_answer_jsonl', #4.75M\n",
378
+ " 'stackexchange_title_best_voted_answer_jsonl', #4.75M\n",
379
+ " 'stackexchange_titlebody_best_and_down_voted_answer_jsonl'] #210K\n",
380
+ "\n",
381
+ "os.environ['GIT_LFS_SKIP_SMUDGE'] = \"1\"\n",
382
+ "\n",
383
+ "# clone stackexchange repos\n",
384
+ "for stack in stacks:\n",
385
+ " os.system(f'git clone https://huggingface.co/datasets/flax-sentence-embeddings/{stack}')\n",
386
+ "# clone reddit repo\n",
387
+ "os.system('git clone https://huggingface.co/datasets/sentence-transformers/reddit-title-body')\n",
388
+ "# clone 1B+ sentence embeddings repo (this one is just for reference)\n",
389
+ "os.system('git clone https://github.com/AntoineSimoulin/1B_sentence_embeddings')"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": null,
395
+ "id": "56438029-0f26-491c-b405-f92fb9caeff7",
396
+ "metadata": {
397
+ "collapsed": true,
398
+ "jupyter": {
399
+ "outputs_hidden": true
400
+ },
401
+ "tags": []
402
+ },
403
+ "outputs": [
404
+ {
405
+ "name": "stdout",
406
+ "output_type": "stream",
407
+ "text": [
408
+ "Downloading 4 StackExchange GitHub datasets into s3://lodestone-rnd/data/\n",
409
+ "upload: ./networkengineering.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/networkengineering.stackexchange.com.json.gz\n",
410
+ "upload: ./emacs.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/emacs.stackexchange.com.json.gz\n",
411
+ "upload: ./christianity.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/christianity.stackexchange.com.json.gz\n",
412
+ "upload: ./bitcoin.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/bitcoin.stackexchange.com.json.gz\n",
413
+ "upload: ./academia.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/academia.stackexchange.com.json.gz\n",
414
+ "upload: ./music.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/music.stackexchange.com.json.gz\n",
415
+ "upload: ./biology.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/biology.stackexchange.com.json.gz\n",
416
+ "upload: ./history.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/history.stackexchange.com.json.gz\n",
417
+ "upload: ./skeptics.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/skeptics.stackexchange.com.json.gz\n",
418
+ "upload: ./anime.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/anime.stackexchange.com.json.gz\n",
419
+ "upload: ./quant.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/quant.stackexchange.com.json.gz\n",
420
+ "upload: ./boardgames.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/boardgames.stackexchange.com.json.gz\n",
421
+ "upload: ./judaism.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/judaism.stackexchange.com.json.gz\n",
422
+ "upload: ./travel.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/travel.stackexchange.com.json.gz\n",
423
+ "upload: ./gaming.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/gaming.stackexchange.com.json.gz\n",
424
+ "upload: ./webapps.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/webapps.stackexchange.com.json.gz\n",
425
+ "upload: ./stats.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/stats.stackexchange.com.json.gz\n",
426
+ "upload: ./law.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/law.stackexchange.com.json.gz\n",
427
+ "upload: ./scifi.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/scifi.stackexchange.com.json.gz\n",
428
+ "upload: ./bicycles.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/bicycles.stackexchange.com.json.gz\n",
429
+ "upload: ./datascience.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/datascience.stackexchange.com.json.gz\n",
430
+ "upload: ./softwareengineering.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/softwareengineering.stackexchange.com.json.gz\n",
431
+ "upload: ./islam.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/islam.stackexchange.com.json.gz\n",
432
+ "upload: ./craftcms.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/craftcms.stackexchange.com.json.gz\n",
433
+ "upload: ./diy.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/diy.stackexchange.com.json.gz\n",
434
+ "upload: ./arduino.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/arduino.stackexchange.com.json.gz\n",
435
+ "upload: ./raspberrypi.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/raspberrypi.stackexchange.com.json.gz\n",
436
+ "upload: ./wordpress.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/wordpress.stackexchange.com.json.gz\n",
437
+ "upload: ./dba.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/dba.stackexchange.com.json.gz\n",
438
+ "upload: ./apple.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/apple.stackexchange.com.json.gz\n",
439
+ "upload: ./hinduism.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/hinduism.stackexchange.com.json.gz\n",
440
+ "upload: ./mechanics.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/mechanics.stackexchange.com.json.gz\n",
441
+ "upload: ./gamedev.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/gamedev.stackexchange.com.json.gz\n",
442
+ "upload: ./writers.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/writers.stackexchange.com.json.gz\n",
443
+ "upload: ./mathematica.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/mathematica.stackexchange.com.json.gz\n",
444
+ "upload: ./unix.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/unix.stackexchange.com.json.gz\n",
445
+ "upload: ./magento.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/magento.stackexchange.com.json.gz\n",
446
+ "upload: ./ethereum.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/ethereum.stackexchange.com.json.gz\n",
447
+ "upload: ./electronics.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/electronics.stackexchange.com.json.gz\n",
448
+ "upload: ./cs.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/cs.stackexchange.com.json.gz\n",
449
+ "upload: ./blender.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/blender.stackexchange.com.json.gz\n",
450
+ "upload: ./drupal.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/drupal.stackexchange.com.json.gz\n",
451
+ "upload: ./small_stackexchanges.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/small_stackexchanges.json.gz\n",
452
+ "upload: ./photo.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/photo.stackexchange.com.json.gz\n",
453
+ "upload: ./engineering.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/engineering.stackexchange.com.json.gz\n",
454
+ "upload: ./ux.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/ux.stackexchange.com.json.gz\n",
455
+ "upload: ./german.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/german.stackexchange.com.json.gz\n",
456
+ "upload: ./japanese.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/japanese.stackexchange.com.json.gz\n",
457
+ "upload: ./civicrm.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/civicrm.stackexchange.com.json.gz\n",
458
+ "upload: ./sharepoint.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/sharepoint.stackexchange.com.json.gz\n",
459
+ "upload: ./mathoverflow.net.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/mathoverflow.net.json.gz\n",
460
+ "upload: ./meta.stackoverflow.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/meta.stackoverflow.com.json.gz\n",
461
+ "upload: ./rpg.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/rpg.stackexchange.com.json.gz\n",
462
+ "upload: ./crypto.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/crypto.stackexchange.com.json.gz\n",
463
+ "upload: ./vi.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/vi.stackexchange.com.json.gz\n",
464
+ "upload: ./graphicdesign.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/graphicdesign.stackexchange.com.json.gz\n",
465
+ "upload: ./cooking.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/cooking.stackexchange.com.json.gz\n",
466
+ "upload: ./math.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/math.stackexchange.com.json.gz\n",
467
+ "upload: ./expressionengine.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/expressionengine.stackexchange.com.json.gz\n",
468
+ "upload: ./movies.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/movies.stackexchange.com.json.gz\n",
469
+ "upload: ./salesforce.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/salesforce.stackexchange.com.json.gz\n",
470
+ "upload: ./physics.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/physics.stackexchange.com.json.gz\n",
471
+ "upload: ./aviation.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/aviation.stackexchange.com.json.gz\n",
472
+ "upload: ./gardening.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/gardening.stackexchange.com.json.gz\n",
473
+ "upload: ./english.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/english.stackexchange.com.json.gz\n",
474
+ "upload: ./askubuntu.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/askubuntu.com.json.gz\n",
475
+ "upload: ./french.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/french.stackexchange.com.json.gz\n",
476
+ "upload: ./codereview.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/codereview.stackexchange.com.json.gz\n",
477
+ "upload: ./softwarerecs.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/softwarerecs.stackexchange.com.json.gz\n",
478
+ "upload: ./rus.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/rus.stackexchange.com.json.gz\n",
479
+ "upload: ./money.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/money.stackexchange.com.json.gz\n",
480
+ "upload: ./philosophy.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/philosophy.stackexchange.com.json.gz\n",
481
+ "upload: ./chemistry.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/chemistry.stackexchange.com.json.gz\n",
482
+ "upload: ./meta.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/meta.stackexchange.com.json.gz\n",
483
+ "upload: ./cstheory.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/cstheory.stackexchange.com.json.gz\n",
484
+ "upload: ./space.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/space.stackexchange.com.json.gz\n",
485
+ "upload: ./politics.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/politics.stackexchange.com.json.gz\n",
486
+ "upload: ./ell.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/ell.stackexchange.com.json.gz\n",
487
+ "upload: ./puzzling.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/puzzling.stackexchange.com.json.gz\n",
488
+ "upload: ./astronomy.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/astronomy.stackexchange.com.json.gz\n",
489
+ "upload: ./worldbuilding.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/worldbuilding.stackexchange.com.json.gz\n",
490
+ "upload: ./economics.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/economics.stackexchange.com.json.gz\n",
491
+ "upload: ./workplace.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/workplace.stackexchange.com.json.gz\n",
492
+ "upload: ./tex.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/tex.stackexchange.com.json.gz\n",
493
+ "upload: ./android.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/android.stackexchange.com.json.gz\n",
494
+ "upload: ./gis.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/gis.stackexchange.com.json.gz\n",
495
+ "upload: ./dsp.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/dsp.stackexchange.com.json.gz\n",
496
+ "upload: ./superuser.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/superuser.com.json.gz\n",
497
+ "upload: ./english.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/english.stackexchange.com.json.gz\n",
498
+ "upload: ./meta.serverfault.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/meta.serverfault.com.json.gz\n",
499
+ "upload: ./scicomp.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/scicomp.stackexchange.com.json.gz\n",
500
+ "upload: ./askubuntu.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/askubuntu.com.json.gz\n",
501
+ "upload: ./french.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/french.stackexchange.com.json.gz\n",
502
+ "upload: ./coffee.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/coffee.stackexchange.com.json.gz\n",
503
+ "upload: ./codereview.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/codereview.stackexchange.com.json.gz\n",
504
+ "upload: ./sound.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/sound.stackexchange.com.json.gz\n",
505
+ "upload: ./opensource.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/opensource.stackexchange.com.json.gz\n",
506
+ "upload: ./woodworking.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/woodworking.stackexchange.com.json.gz\n",
507
+ "upload: ./outdoors.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/outdoors.stackexchange.com.json.gz\n",
508
+ "upload: ./reddit_title_text_2018.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2018.json.gz\n",
509
+ "upload: ./reddit_title_text_2011.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2011.json.gz\n",
510
+ "upload: ./reddit_title_text_2020.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2020.json.gz\n",
511
+ "upload: ./reddit_title_text_2012.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2012.json.gz\n",
512
+ "upload: ./reddit_title_text_2021.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2021.json.gz\n",
513
+ "upload: ./reddit_title_text_2019.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2019.json.gz\n",
514
+ "upload: ./reddit_title_text_2010.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2010.json.gz\n",
515
+ "upload: ./reddit_title_text_2014.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2014.json.gz\n",
516
+ "\u001b[32mDone\u001b[0m\n",
517
+ "Total files uploaded: 12\n",
518
+ "\n",
519
+ "\n",
520
+ "Downloading 9 HuggingFace datasets into s3://lodestone-rnd/data/\n",
521
+ "Downloading dataset S2ORC_citations_abstracts (39,567,485 pairs) ... "
522
+ ]
523
+ }
524
+ ],
525
+ "source": [
526
+ "# DOWNLOAD GITHUB DATASETS (STACKEXCHANGE (https://huggingface.co/flax-sentence-embeddings) & REDDIT (https://huggingface.co/datasets/sentence-transformers/reddit-title-body))\n",
527
+ "\n",
528
+ "# these are the files marked as unsafe by HuggingFace when viewing the each of the datasets' pages\n",
529
+ "unsafe = [[\"serverfault.com.jsonl.gz\", \"security.stackexchange.com.jsonl.gz\"],\n",
530
+ " [\"monero.stackexchange.com.jsonl.gz\", \"serverfault.com.jsonl.gz\", \"security.stackexchange.com.jsonl.gz\"],\n",
531
+ " [\"elementaryos.stackexchange.com.jsonl.gz\", \"monero.stackexchange.com.jsonl.gz\", \"security.stackexchange.com.jsonl.gz\"],\n",
532
+ " [\"\"]]\n",
533
+ "\n",
534
+ "file_counts = []\n",
535
+ "print('Downloading {:,} StackExchange GitHub datasets into s3://lodestone-rnd/data/'.format(len(stacks)))\n",
536
+ "for i, stack in enumerate(stacks):\n",
537
+ " # get the names of all the files in the repository that are not unsafe\n",
538
+ " files = [file for file in os.listdir(f'/home/ec2-user/SageMaker/{stack}') if file.endswith(\".jsonl.gz\")==True if file not in unsafe[i]]\n",
539
+ " file_counts.append(len(files))\n",
540
+ " os.chdir(f'/home/ec2-user/SageMaker/{stack}')\n",
541
+ " print('Downloading dataset {} ({} files) ... '.format(stack, len(files)), end='', flush=True)\n",
542
+ " # sequentially pull each dataset from git lfs, stream it to S3, and then delete the local copy to free up disk memory\n",
543
+ " for file_name in files:\n",
544
+ " os.system(f'git lfs pull --include={file_name}')\n",
545
+ " os.system(f'aws s3 cp {file_name} s3://lodestone-rnd/data/{stack}/{file_name[:-9] + \".json.gz\"}')\n",
546
+ " os.remove(file_name)\n",
547
+ " os.system('rm -r .git/lfs/objects/*')\n",
548
+ " if len(os.listdir('.git/objects/pack')) == 4:\n",
549
+ " os.system('ls -t .git/objects/pack/* | head -2 | xargs rm --')\n",
550
+ " print('\\033[32m' + 'Done' + '\\033[0m')\n",
551
+ "print(f'Total files uploaded: {sum(file_counts)}')\n",
552
+ "\n",
553
+ "print(\"\\n\")\n",
554
+ "\n",
555
+ "print('Downloading {:,} Reddit GitHub dataset into s3://lodestone-rnd/data/'.format(1))\n",
556
+ "# get the names of all the files in the repository\n",
557
+ "files = [file for file in os.listdir(f'/home/ec2-user/SageMaker/reddit-title-body') if file.endswith(\".jsonl.gz\")==True]\n",
558
+ "os.chdir(f'/home/ec2-user/SageMaker/reddit-title-body')\n",
559
+ "print('Downloading dataset {} ({} files) ... '.format(\"reddit-title-body\", len(files)), end='', flush=True)\n",
560
+ "# sequentially pull each dataset from git lfs, stream it to S3, and then delete the local copy to free up disk memory\n",
561
+ "for file_name in files:\n",
562
+ " os.system(f'git lfs pull --include={file_name}')\n",
563
+ " os.system(f'aws s3 cp {file_name} s3://lodestone-rnd/data/reddit-title-body/{file_name[:-9] + \".json.gz\"}')\n",
564
+ " os.remove(file_name)\n",
565
+ " os.system('rm -r .git/lfs/objects/*')\n",
566
+ " if len(os.listdir('.git/objects/pack')) == 4:\n",
567
+ " os.system('ls -t .git/objects/pack/* | head -2 | xargs rm --')\n",
568
+ "print('\\033[32m' + 'Done' + '\\033[0m')\n",
569
+ "print(f'Total files uploaded: {len(files)}')\n",
570
+ "\n",
571
+ "os.chdir('/home/ec2-user/SageMaker')\n",
572
+ "\n",
573
+ "print(\"\\n\")\n",
574
+ "\n",
575
+ "# DOWNLOAD HUGGINGFACE DATASETS (https://huggingface.co/datasets/sentence-transformers/embedding-training-data)\n",
576
+ "\n",
577
+ "# read dataset information from HuggingFace_datasets.tsv\n",
578
+ "datasets = pd.read_csv(\n",
579
+ " 'HuggingFace_datasets.tsv',\n",
580
+ " index_col=0,\n",
581
+ " sep='\\t',\n",
582
+ " dtype={\n",
583
+ " 'Description': str,\n",
584
+ " 'Size (#Pairs)': str,\n",
585
+ " 'Performance': float,\n",
586
+ " 'Download link': str,\n",
587
+ " 'Source': str})\n",
588
+ "datasets['Size (#Pairs)'] = datasets['Size (#Pairs)'].str.replace(',', '').astype(int)\n",
589
+ "datasets = datasets.to_dict(orient='index')\n",
590
+ "\n",
591
+ "print('Downloading {:,} HuggingFace datasets into s3://lodestone-rnd/data/'.format(len(datasets)))\n",
592
+ "\n",
593
+ "# stream each of the datasets from the URL provided into S3\n",
594
+ "# (note that S2ORC_citations_abstracts is larger than 50GB and therefore requires the expected size to be passed into the command line as well)\n",
595
+ "for d in datasets.keys():\n",
596
+ " print('Downloading dataset {} ({:,} pairs) ... '.format(d, datasets[d]['Size (#Pairs)']), end='', flush=True)\n",
597
+ " if d == \"S2ORC_citations_abstracts\":\n",
598
+ " os.system(f'wget -qO- {datasets[d][\"Download link\"]} | aws s3 cp - s3://lodestone-rnd/data/{d + \".json.gz\"} --expected-size 120259084288')\n",
599
+ " else:\n",
600
+ " os.system(f'wget -qO- {datasets[d][\"Download link\"]} | aws s3 cp - s3://lodestone-rnd/data/{d + \".json.gz\"}')\n",
601
+ " print('\\033[32m' + 'Done' + '\\033[0m')\n",
602
+ "\n",
603
+ "print(\"\\n\")\n",
604
+ "\n",
605
+ "# DOWNLOAD GOOGLE SHEETS DATASETS (https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0)\n",
606
+ "\n",
607
+ "# read dataset information from GoogleSheets_datasets.tsv\n",
608
+ "datasets = pd.read_csv(\n",
609
+ " 'GoogleSheets_datasets.tsv',\n",
610
+ " index_col=0,\n",
611
+ " sep='\\t',\n",
612
+ " dtype={\n",
613
+ " 'Description': str,\n",
614
+ " 'Size (#Pairs)': str,\n",
615
+ " 'Performance': float,\n",
616
+ " 'Download link': str,\n",
617
+ " 'Source': str})\n",
618
+ "datasets['Size (#Pairs)'] = datasets['Size (#Pairs)'].str.replace(',', '').astype(int)\n",
619
+ "datasets = datasets.to_dict(orient='index')\n",
620
+ "\n",
621
+ "print('Downloading {:,} 1B+ Google Sheets datasets into s3://lodestone-rnd/data/'.format(len(datasets)))\n",
622
+ "\n",
623
+ "# stream each of the datasets from the URL provided into S3\n",
624
+ "for d in datasets.keys():\n",
625
+ " print('Downloading dataset {} ({:,} pairs) ... '.format(d, datasets[d]['Size (#Pairs)']), end='', flush=True)\n",
626
+ " os.system(f'wget -qO- {datasets[d][\"Download link\"]} | aws s3 cp - s3://lodestone-rnd/data/{d + \".json.gz\"}')\n",
627
+ " print('\\033[32m' + 'Done' + '\\033[0m')\n",
628
+ "\n",
629
+ "print(\"\\n\")\n",
630
+ " \n",
631
+ "print(f'Successfully downloaded 50 datasets and {621+12+9+36} files into s3://lodestone-rnd/data/')"
632
+ ]
633
+ },
634
+ {
635
+ "cell_type": "code",
636
+ "execution_count": 1,
637
+ "id": "216d4564-6b69-4688-94ec-a67287134a2d",
638
+ "metadata": {
639
+ "tags": []
640
+ },
641
+ "outputs": [],
642
+ "source": [
643
+ "# clean up (remove the cloned repositories)\n",
644
+ "import shutil\n",
645
+ "shutil.rmtree(\"/home/ec2-user/SageMaker/stackexchange_title_body_jsonl\")\n",
646
+ "shutil.rmtree(\"/home/ec2-user/SageMaker/stackexchange_titlebody_best_voted_answer_jsonl\")\n",
647
+ "shutil.rmtree(\"/home/ec2-user/SageMaker/stackexchange_title_best_voted_answer_jsonl\")\n",
648
+ "shutil.rmtree(\"/home/ec2-user/SageMaker/stackexchange_titlebody_best_and_down_voted_answer_jsonl\")\n",
649
+ "shutil.rmtree(\"/home/ec2-user/SageMaker/reddit-title-body\")\n",
650
+ "shutil.rmtree(\"/home/ec2-user/SageMaker/1B_sentence_embeddings\")"
651
+ ]
652
+ }
653
+ ],
654
+ "metadata": {
655
+ "kernelspec": {
656
+ "display_name": "conda_pytorch_p310",
657
+ "language": "python",
658
+ "name": "conda_pytorch_p310"
659
+ },
660
+ "language_info": {
661
+ "codemirror_mode": {
662
+ "name": "ipython",
663
+ "version": 3
664
+ },
665
+ "file_extension": ".py",
666
+ "mimetype": "text/x-python",
667
+ "name": "python",
668
+ "nbconvert_exporter": "python",
669
+ "pygments_lexer": "ipython3",
670
+ "version": "3.10.10"
671
+ }
672
+ },
673
+ "nbformat": 4,
674
+ "nbformat_minor": 5
675
+ }
README.md CHANGED
@@ -1,3 +1,1940 @@
1
  ---
2
  license: apache-2.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ pipeline_tag: sentence-similarity
4
+ tags:
5
+ - sentence-transformers
6
+ - feature-extraction
7
+ - sentence-similarity
8
+ - mteb
9
+ language: en
10
+ datasets:
11
+ - s2orc
12
+ - flax-sentence-embeddings/stackexchange_title_body_jsonl
13
+ - flax-sentence-embeddings/stackexchange_titlebody_best_voted_answer_jsonl
14
+ - flax-sentence-embeddings/stackexchange_title_best_voted_answer_jsonl
15
+ - flax-sentence-embeddings/stackexchange_titlebody_best_and_down_voted_answer_jsonl
16
+ - sentence-transformers/reddit-title-body
17
+ - msmarco
18
+ - gooaq
19
+ - yahoo_answers_topics
20
+ - code_search_net
21
+ - search_qa
22
+ - eli5
23
+ - snli
24
+ - multi_nli
25
+ - wikihow
26
+ - natural_questions
27
+ - trivia_qa
28
+ - embedding-data/sentence-compression
29
+ - embedding-data/flickr30k-captions
30
+ - embedding-data/altlex
31
+ - embedding-data/simple-wiki
32
+ - embedding-data/QQP
33
+ - embedding-data/SPECTER
34
+ - embedding-data/PAQ_pairs
35
+ - embedding-data/WikiAnswers
36
+ - sentence-transformers/embedding-training-data
37
+ model-index:
38
+ - name: lodestone-base-4096-v1
39
+ results:
40
+ - task:
41
+ type: Classification
42
+ dataset:
43
+ type: mteb/amazon_counterfactual
44
+ name: MTEB AmazonCounterfactualClassification (en)
45
+ config: en
46
+ split: test
47
+ revision: e8379541af4e31359cca9fbcf4b00f2671dba205
48
+ metrics:
49
+ - type: accuracy
50
+ value: 69.7313432835821
51
+ - type: ap
52
+ value: 31.618259511417733
53
+ - type: f1
54
+ value: 63.30313825394228
55
+ - task:
56
+ type: Classification
57
+ dataset:
58
+ type: mteb/amazon_polarity
59
+ name: MTEB AmazonPolarityClassification
60
+ config: default
61
+ split: test
62
+ revision: e2d317d38cd51312af73b3d32a06d1a08b442046
63
+ metrics:
64
+ - type: accuracy
65
+ value: 86.89837499999999
66
+ - type: ap
67
+ value: 82.39500885672128
68
+ - type: f1
69
+ value: 86.87317947399657
70
+ - task:
71
+ type: Classification
72
+ dataset:
73
+ type: mteb/amazon_reviews_multi
74
+ name: MTEB AmazonReviewsClassification (en)
75
+ config: en
76
+ split: test
77
+ revision: 1399c76144fd37290681b995c656ef9b2e06e26d
78
+ metrics:
79
+ - type: accuracy
80
+ value: 44.05
81
+ - type: f1
82
+ value: 42.67624383248947
83
+ - task:
84
+ type: Retrieval
85
+ dataset:
86
+ type: arguana
87
+ name: MTEB ArguAna
88
+ config: default
89
+ split: test
90
+ revision: None
91
+ metrics:
92
+ - type: map_at_1
93
+ value: 26.173999999999996
94
+ - type: map_at_10
95
+ value: 40.976
96
+ - type: map_at_100
97
+ value: 42.067
98
+ - type: map_at_1000
99
+ value: 42.075
100
+ - type: map_at_3
101
+ value: 35.917
102
+ - type: map_at_5
103
+ value: 38.656
104
+ - type: mrr_at_1
105
+ value: 26.814
106
+ - type: mrr_at_10
107
+ value: 41.252
108
+ - type: mrr_at_100
109
+ value: 42.337
110
+ - type: mrr_at_1000
111
+ value: 42.345
112
+ - type: mrr_at_3
113
+ value: 36.226
114
+ - type: mrr_at_5
115
+ value: 38.914
116
+ - type: ndcg_at_1
117
+ value: 26.173999999999996
118
+ - type: ndcg_at_10
119
+ value: 49.819
120
+ - type: ndcg_at_100
121
+ value: 54.403999999999996
122
+ - type: ndcg_at_1000
123
+ value: 54.59
124
+ - type: ndcg_at_3
125
+ value: 39.231
126
+ - type: ndcg_at_5
127
+ value: 44.189
128
+ - type: precision_at_1
129
+ value: 26.173999999999996
130
+ - type: precision_at_10
131
+ value: 7.838000000000001
132
+ - type: precision_at_100
133
+ value: 0.9820000000000001
134
+ - type: precision_at_1000
135
+ value: 0.1
136
+ - type: precision_at_3
137
+ value: 16.287
138
+ - type: precision_at_5
139
+ value: 12.191
140
+ - type: recall_at_1
141
+ value: 26.173999999999996
142
+ - type: recall_at_10
143
+ value: 78.378
144
+ - type: recall_at_100
145
+ value: 98.222
146
+ - type: recall_at_1000
147
+ value: 99.644
148
+ - type: recall_at_3
149
+ value: 48.862
150
+ - type: recall_at_5
151
+ value: 60.953
152
+ - task:
153
+ type: Clustering
154
+ dataset:
155
+ type: mteb/arxiv-clustering-p2p
156
+ name: MTEB ArxivClusteringP2P
157
+ config: default
158
+ split: test
159
+ revision: a122ad7f3f0291bf49cc6f4d32aa80929df69d5d
160
+ metrics:
161
+ - type: v_measure
162
+ value: 42.31689035788179
163
+ - task:
164
+ type: Clustering
165
+ dataset:
166
+ type: mteb/arxiv-clustering-s2s
167
+ name: MTEB ArxivClusteringS2S
168
+ config: default
169
+ split: test
170
+ revision: f910caf1a6075f7329cdf8c1a6135696f37dbd53
171
+ metrics:
172
+ - type: v_measure
173
+ value: 31.280245136660984
174
+ - task:
175
+ type: Reranking
176
+ dataset:
177
+ type: mteb/askubuntudupquestions-reranking
178
+ name: MTEB AskUbuntuDupQuestions
179
+ config: default
180
+ split: test
181
+ revision: 2000358ca161889fa9c082cb41daa8dcfb161a54
182
+ metrics:
183
+ - type: map
184
+ value: 58.79109720839415
185
+ - type: mrr
186
+ value: 71.79615705931495
187
+ - task:
188
+ type: STS
189
+ dataset:
190
+ type: mteb/biosses-sts
191
+ name: MTEB BIOSSES
192
+ config: default
193
+ split: test
194
+ revision: d3fb88f8f02e40887cd149695127462bbcf29b4a
195
+ metrics:
196
+ - type: cos_sim_pearson
197
+ value: 76.44918756608115
198
+ - type: cos_sim_spearman
199
+ value: 70.86607256286257
200
+ - type: euclidean_pearson
201
+ value: 74.12154678100815
202
+ - type: euclidean_spearman
203
+ value: 70.86607256286257
204
+ - type: manhattan_pearson
205
+ value: 74.0078626964417
206
+ - type: manhattan_spearman
207
+ value: 70.68353828321327
208
+ - task:
209
+ type: Classification
210
+ dataset:
211
+ type: mteb/banking77
212
+ name: MTEB Banking77Classification
213
+ config: default
214
+ split: test
215
+ revision: 0fd18e25b25c072e09e0d92ab615fda904d66300
216
+ metrics:
217
+ - type: accuracy
218
+ value: 75.40584415584415
219
+ - type: f1
220
+ value: 74.29514617572676
221
+ - task:
222
+ type: Clustering
223
+ dataset:
224
+ type: mteb/biorxiv-clustering-p2p
225
+ name: MTEB BiorxivClusteringP2P
226
+ config: default
227
+ split: test
228
+ revision: 65b79d1d13f80053f67aca9498d9402c2d9f1f40
229
+ metrics:
230
+ - type: v_measure
231
+ value: 37.41860080664014
232
+ - task:
233
+ type: Clustering
234
+ dataset:
235
+ type: mteb/biorxiv-clustering-s2s
236
+ name: MTEB BiorxivClusteringS2S
237
+ config: default
238
+ split: test
239
+ revision: 258694dd0231531bc1fd9de6ceb52a0853c6d908
240
+ metrics:
241
+ - type: v_measure
242
+ value: 29.319217023090705
243
+ - task:
244
+ type: Retrieval
245
+ dataset:
246
+ type: BeIR/cqadupstack
247
+ name: MTEB CQADupstackEnglishRetrieval
248
+ config: default
249
+ split: test
250
+ revision: None
251
+ metrics:
252
+ - type: map_at_1
253
+ value: 22.528000000000002
254
+ - type: map_at_10
255
+ value: 30.751
256
+ - type: map_at_100
257
+ value: 31.855
258
+ - type: map_at_1000
259
+ value: 31.972
260
+ - type: map_at_3
261
+ value: 28.465
262
+ - type: map_at_5
263
+ value: 29.738
264
+ - type: mrr_at_1
265
+ value: 28.662
266
+ - type: mrr_at_10
267
+ value: 35.912
268
+ - type: mrr_at_100
269
+ value: 36.726
270
+ - type: mrr_at_1000
271
+ value: 36.777
272
+ - type: mrr_at_3
273
+ value: 34.013
274
+ - type: mrr_at_5
275
+ value: 35.156
276
+ - type: ndcg_at_1
277
+ value: 28.662
278
+ - type: ndcg_at_10
279
+ value: 35.452
280
+ - type: ndcg_at_100
281
+ value: 40.1
282
+ - type: ndcg_at_1000
283
+ value: 42.323
284
+ - type: ndcg_at_3
285
+ value: 32.112
286
+ - type: ndcg_at_5
287
+ value: 33.638
288
+ - type: precision_at_1
289
+ value: 28.662
290
+ - type: precision_at_10
291
+ value: 6.688
292
+ - type: precision_at_100
293
+ value: 1.13
294
+ - type: precision_at_1000
295
+ value: 0.16
296
+ - type: precision_at_3
297
+ value: 15.562999999999999
298
+ - type: precision_at_5
299
+ value: 11.019
300
+ - type: recall_at_1
301
+ value: 22.528000000000002
302
+ - type: recall_at_10
303
+ value: 43.748
304
+ - type: recall_at_100
305
+ value: 64.235
306
+ - type: recall_at_1000
307
+ value: 78.609
308
+ - type: recall_at_3
309
+ value: 33.937
310
+ - type: recall_at_5
311
+ value: 38.234
312
+ - task:
313
+ type: Retrieval
314
+ dataset:
315
+ type: climate-fever
316
+ name: MTEB ClimateFEVER
317
+ config: default
318
+ split: test
319
+ revision: None
320
+ metrics:
321
+ - type: map_at_1
322
+ value: 9.468
323
+ - type: map_at_10
324
+ value: 16.029
325
+ - type: map_at_100
326
+ value: 17.693
327
+ - type: map_at_1000
328
+ value: 17.886
329
+ - type: map_at_3
330
+ value: 13.15
331
+ - type: map_at_5
332
+ value: 14.568
333
+ - type: mrr_at_1
334
+ value: 21.173000000000002
335
+ - type: mrr_at_10
336
+ value: 31.028
337
+ - type: mrr_at_100
338
+ value: 32.061
339
+ - type: mrr_at_1000
340
+ value: 32.119
341
+ - type: mrr_at_3
342
+ value: 27.534999999999997
343
+ - type: mrr_at_5
344
+ value: 29.431
345
+ - type: ndcg_at_1
346
+ value: 21.173000000000002
347
+ - type: ndcg_at_10
348
+ value: 23.224
349
+ - type: ndcg_at_100
350
+ value: 30.225
351
+ - type: ndcg_at_1000
352
+ value: 33.961000000000006
353
+ - type: ndcg_at_3
354
+ value: 18.174
355
+ - type: ndcg_at_5
356
+ value: 19.897000000000002
357
+ - type: precision_at_1
358
+ value: 21.173000000000002
359
+ - type: precision_at_10
360
+ value: 7.4719999999999995
361
+ - type: precision_at_100
362
+ value: 1.5010000000000001
363
+ - type: precision_at_1000
364
+ value: 0.219
365
+ - type: precision_at_3
366
+ value: 13.312
367
+ - type: precision_at_5
368
+ value: 10.619
369
+ - type: recall_at_1
370
+ value: 9.468
371
+ - type: recall_at_10
372
+ value: 28.823
373
+ - type: recall_at_100
374
+ value: 53.26499999999999
375
+ - type: recall_at_1000
376
+ value: 74.536
377
+ - type: recall_at_3
378
+ value: 16.672
379
+ - type: recall_at_5
380
+ value: 21.302
381
+ - task:
382
+ type: Retrieval
383
+ dataset:
384
+ type: dbpedia-entity
385
+ name: MTEB DBPedia
386
+ config: default
387
+ split: test
388
+ revision: None
389
+ metrics:
390
+ - type: map_at_1
391
+ value: 6.343
392
+ - type: map_at_10
393
+ value: 12.717
394
+ - type: map_at_100
395
+ value: 16.48
396
+ - type: map_at_1000
397
+ value: 17.381
398
+ - type: map_at_3
399
+ value: 9.568999999999999
400
+ - type: map_at_5
401
+ value: 11.125
402
+ - type: mrr_at_1
403
+ value: 48.75
404
+ - type: mrr_at_10
405
+ value: 58.425000000000004
406
+ - type: mrr_at_100
407
+ value: 59.075
408
+ - type: mrr_at_1000
409
+ value: 59.095
410
+ - type: mrr_at_3
411
+ value: 56.291999999999994
412
+ - type: mrr_at_5
413
+ value: 57.679
414
+ - type: ndcg_at_1
415
+ value: 37.875
416
+ - type: ndcg_at_10
417
+ value: 27.77
418
+ - type: ndcg_at_100
419
+ value: 30.288999999999998
420
+ - type: ndcg_at_1000
421
+ value: 36.187999999999995
422
+ - type: ndcg_at_3
423
+ value: 31.385999999999996
424
+ - type: ndcg_at_5
425
+ value: 29.923
426
+ - type: precision_at_1
427
+ value: 48.75
428
+ - type: precision_at_10
429
+ value: 22.375
430
+ - type: precision_at_100
431
+ value: 6.3420000000000005
432
+ - type: precision_at_1000
433
+ value: 1.4489999999999998
434
+ - type: precision_at_3
435
+ value: 35.5
436
+ - type: precision_at_5
437
+ value: 30.55
438
+ - type: recall_at_1
439
+ value: 6.343
440
+ - type: recall_at_10
441
+ value: 16.936
442
+ - type: recall_at_100
443
+ value: 35.955999999999996
444
+ - type: recall_at_1000
445
+ value: 55.787
446
+ - type: recall_at_3
447
+ value: 10.771
448
+ - type: recall_at_5
449
+ value: 13.669999999999998
450
+ - task:
451
+ type: Classification
452
+ dataset:
453
+ type: mteb/emotion
454
+ name: MTEB EmotionClassification
455
+ config: default
456
+ split: test
457
+ revision: 4f58c6b202a23cf9a4da393831edf4f9183cad37
458
+ metrics:
459
+ - type: accuracy
460
+ value: 41.99
461
+ - type: f1
462
+ value: 36.823402174564954
463
+ - task:
464
+ type: Retrieval
465
+ dataset:
466
+ type: fever
467
+ name: MTEB FEVER
468
+ config: default
469
+ split: test
470
+ revision: None
471
+ metrics:
472
+ - type: map_at_1
473
+ value: 40.088
474
+ - type: map_at_10
475
+ value: 52.69200000000001
476
+ - type: map_at_100
477
+ value: 53.296
478
+ - type: map_at_1000
479
+ value: 53.325
480
+ - type: map_at_3
481
+ value: 49.905
482
+ - type: map_at_5
483
+ value: 51.617000000000004
484
+ - type: mrr_at_1
485
+ value: 43.009
486
+ - type: mrr_at_10
487
+ value: 56.203
488
+ - type: mrr_at_100
489
+ value: 56.75
490
+ - type: mrr_at_1000
491
+ value: 56.769000000000005
492
+ - type: mrr_at_3
493
+ value: 53.400000000000006
494
+ - type: mrr_at_5
495
+ value: 55.163
496
+ - type: ndcg_at_1
497
+ value: 43.009
498
+ - type: ndcg_at_10
499
+ value: 59.39
500
+ - type: ndcg_at_100
501
+ value: 62.129999999999995
502
+ - type: ndcg_at_1000
503
+ value: 62.793
504
+ - type: ndcg_at_3
505
+ value: 53.878
506
+ - type: ndcg_at_5
507
+ value: 56.887
508
+ - type: precision_at_1
509
+ value: 43.009
510
+ - type: precision_at_10
511
+ value: 8.366
512
+ - type: precision_at_100
513
+ value: 0.983
514
+ - type: precision_at_1000
515
+ value: 0.105
516
+ - type: precision_at_3
517
+ value: 22.377
518
+ - type: precision_at_5
519
+ value: 15.035000000000002
520
+ - type: recall_at_1
521
+ value: 40.088
522
+ - type: recall_at_10
523
+ value: 76.68700000000001
524
+ - type: recall_at_100
525
+ value: 88.91
526
+ - type: recall_at_1000
527
+ value: 93.782
528
+ - type: recall_at_3
529
+ value: 61.809999999999995
530
+ - type: recall_at_5
531
+ value: 69.131
532
+ - task:
533
+ type: Retrieval
534
+ dataset:
535
+ type: fiqa
536
+ name: MTEB FiQA2018
537
+ config: default
538
+ split: test
539
+ revision: None
540
+ metrics:
541
+ - type: map_at_1
542
+ value: 10.817
543
+ - type: map_at_10
544
+ value: 18.9
545
+ - type: map_at_100
546
+ value: 20.448
547
+ - type: map_at_1000
548
+ value: 20.660999999999998
549
+ - type: map_at_3
550
+ value: 15.979
551
+ - type: map_at_5
552
+ value: 17.415
553
+ - type: mrr_at_1
554
+ value: 23.148
555
+ - type: mrr_at_10
556
+ value: 31.208000000000002
557
+ - type: mrr_at_100
558
+ value: 32.167
559
+ - type: mrr_at_1000
560
+ value: 32.242
561
+ - type: mrr_at_3
562
+ value: 28.498
563
+ - type: mrr_at_5
564
+ value: 29.964000000000002
565
+ - type: ndcg_at_1
566
+ value: 23.148
567
+ - type: ndcg_at_10
568
+ value: 25.325999999999997
569
+ - type: ndcg_at_100
570
+ value: 31.927
571
+ - type: ndcg_at_1000
572
+ value: 36.081
573
+ - type: ndcg_at_3
574
+ value: 21.647
575
+ - type: ndcg_at_5
576
+ value: 22.762999999999998
577
+ - type: precision_at_1
578
+ value: 23.148
579
+ - type: precision_at_10
580
+ value: 7.546
581
+ - type: precision_at_100
582
+ value: 1.415
583
+ - type: precision_at_1000
584
+ value: 0.216
585
+ - type: precision_at_3
586
+ value: 14.969
587
+ - type: precision_at_5
588
+ value: 11.327
589
+ - type: recall_at_1
590
+ value: 10.817
591
+ - type: recall_at_10
592
+ value: 32.164
593
+ - type: recall_at_100
594
+ value: 57.655
595
+ - type: recall_at_1000
596
+ value: 82.797
597
+ - type: recall_at_3
598
+ value: 19.709
599
+ - type: recall_at_5
600
+ value: 24.333
601
+ - task:
602
+ type: Retrieval
603
+ dataset:
604
+ type: hotpotqa
605
+ name: MTEB HotpotQA
606
+ config: default
607
+ split: test
608
+ revision: None
609
+ metrics:
610
+ - type: map_at_1
611
+ value: 25.380999999999997
612
+ - type: map_at_10
613
+ value: 33.14
614
+ - type: map_at_100
615
+ value: 33.948
616
+ - type: map_at_1000
617
+ value: 34.028000000000006
618
+ - type: map_at_3
619
+ value: 31.019999999999996
620
+ - type: map_at_5
621
+ value: 32.23
622
+ - type: mrr_at_1
623
+ value: 50.763000000000005
624
+ - type: mrr_at_10
625
+ value: 57.899
626
+ - type: mrr_at_100
627
+ value: 58.426
628
+ - type: mrr_at_1000
629
+ value: 58.457
630
+ - type: mrr_at_3
631
+ value: 56.093
632
+ - type: mrr_at_5
633
+ value: 57.116
634
+ - type: ndcg_at_1
635
+ value: 50.763000000000005
636
+ - type: ndcg_at_10
637
+ value: 41.656
638
+ - type: ndcg_at_100
639
+ value: 45.079
640
+ - type: ndcg_at_1000
641
+ value: 46.916999999999994
642
+ - type: ndcg_at_3
643
+ value: 37.834
644
+ - type: ndcg_at_5
645
+ value: 39.732
646
+ - type: precision_at_1
647
+ value: 50.763000000000005
648
+ - type: precision_at_10
649
+ value: 8.648
650
+ - type: precision_at_100
651
+ value: 1.135
652
+ - type: precision_at_1000
653
+ value: 0.13799999999999998
654
+ - type: precision_at_3
655
+ value: 23.105999999999998
656
+ - type: precision_at_5
657
+ value: 15.363
658
+ - type: recall_at_1
659
+ value: 25.380999999999997
660
+ - type: recall_at_10
661
+ value: 43.241
662
+ - type: recall_at_100
663
+ value: 56.745000000000005
664
+ - type: recall_at_1000
665
+ value: 69.048
666
+ - type: recall_at_3
667
+ value: 34.659
668
+ - type: recall_at_5
669
+ value: 38.406
670
+ - task:
671
+ type: Classification
672
+ dataset:
673
+ type: mteb/imdb
674
+ name: MTEB ImdbClassification
675
+ config: default
676
+ split: test
677
+ revision: 3d86128a09e091d6018b6d26cad27f2739fc2db7
678
+ metrics:
679
+ - type: accuracy
680
+ value: 79.544
681
+ - type: ap
682
+ value: 73.82920133396664
683
+ - type: f1
684
+ value: 79.51048124883265
685
+ - task:
686
+ type: Retrieval
687
+ dataset:
688
+ type: msmarco
689
+ name: MTEB MSMARCO
690
+ config: default
691
+ split: dev
692
+ revision: None
693
+ metrics:
694
+ - type: map_at_1
695
+ value: 11.174000000000001
696
+ - type: map_at_10
697
+ value: 19.451999999999998
698
+ - type: map_at_100
699
+ value: 20.612
700
+ - type: map_at_1000
701
+ value: 20.703
702
+ - type: map_at_3
703
+ value: 16.444
704
+ - type: map_at_5
705
+ value: 18.083
706
+ - type: mrr_at_1
707
+ value: 11.447000000000001
708
+ - type: mrr_at_10
709
+ value: 19.808
710
+ - type: mrr_at_100
711
+ value: 20.958
712
+ - type: mrr_at_1000
713
+ value: 21.041999999999998
714
+ - type: mrr_at_3
715
+ value: 16.791
716
+ - type: mrr_at_5
717
+ value: 18.459
718
+ - type: ndcg_at_1
719
+ value: 11.447000000000001
720
+ - type: ndcg_at_10
721
+ value: 24.556
722
+ - type: ndcg_at_100
723
+ value: 30.637999999999998
724
+ - type: ndcg_at_1000
725
+ value: 33.14
726
+ - type: ndcg_at_3
727
+ value: 18.325
728
+ - type: ndcg_at_5
729
+ value: 21.278
730
+ - type: precision_at_1
731
+ value: 11.447000000000001
732
+ - type: precision_at_10
733
+ value: 4.215
734
+ - type: precision_at_100
735
+ value: 0.732
736
+ - type: precision_at_1000
737
+ value: 0.095
738
+ - type: precision_at_3
739
+ value: 8.052
740
+ - type: precision_at_5
741
+ value: 6.318
742
+ - type: recall_at_1
743
+ value: 11.174000000000001
744
+ - type: recall_at_10
745
+ value: 40.543
746
+ - type: recall_at_100
747
+ value: 69.699
748
+ - type: recall_at_1000
749
+ value: 89.403
750
+ - type: recall_at_3
751
+ value: 23.442
752
+ - type: recall_at_5
753
+ value: 30.536
754
+ - task:
755
+ type: Classification
756
+ dataset:
757
+ type: mteb/mtop_domain
758
+ name: MTEB MTOPDomainClassification (en)
759
+ config: en
760
+ split: test
761
+ revision: d80d48c1eb48d3562165c59d59d0034df9fff0bf
762
+ metrics:
763
+ - type: accuracy
764
+ value: 89.6671226630187
765
+ - type: f1
766
+ value: 89.57660424361246
767
+ - task:
768
+ type: Classification
769
+ dataset:
770
+ type: mteb/mtop_intent
771
+ name: MTEB MTOPIntentClassification (en)
772
+ config: en
773
+ split: test
774
+ revision: ae001d0e6b1228650b7bd1c2c65fb50ad11a8aba
775
+ metrics:
776
+ - type: accuracy
777
+ value: 60.284997720018254
778
+ - type: f1
779
+ value: 40.30637400152823
780
+ - task:
781
+ type: Classification
782
+ dataset:
783
+ type: mteb/amazon_massive_intent
784
+ name: MTEB MassiveIntentClassification (en)
785
+ config: en
786
+ split: test
787
+ revision: 31efe3c427b0bae9c22cbb560b8f15491cc6bed7
788
+ metrics:
789
+ - type: accuracy
790
+ value: 63.33557498318763
791
+ - type: f1
792
+ value: 60.24039910680179
793
+ - task:
794
+ type: Classification
795
+ dataset:
796
+ type: mteb/amazon_massive_scenario
797
+ name: MTEB MassiveScenarioClassification (en)
798
+ config: en
799
+ split: test
800
+ revision: 7d571f92784cd94a019292a1f45445077d0ef634
801
+ metrics:
802
+ - type: accuracy
803
+ value: 72.37390719569603
804
+ - type: f1
805
+ value: 72.33097333477316
806
+ - task:
807
+ type: Clustering
808
+ dataset:
809
+ type: mteb/medrxiv-clustering-p2p
810
+ name: MTEB MedrxivClusteringP2P
811
+ config: default
812
+ split: test
813
+ revision: e7a26af6f3ae46b30dde8737f02c07b1505bcc73
814
+ metrics:
815
+ - type: v_measure
816
+ value: 34.68158939060552
817
+ - task:
818
+ type: Clustering
819
+ dataset:
820
+ type: mteb/medrxiv-clustering-s2s
821
+ name: MTEB MedrxivClusteringS2S
822
+ config: default
823
+ split: test
824
+ revision: 35191c8c0dca72d8ff3efcd72aa802307d469663
825
+ metrics:
826
+ - type: v_measure
827
+ value: 30.340061711905236
828
+ - task:
829
+ type: Reranking
830
+ dataset:
831
+ type: mteb/mind_small
832
+ name: MTEB MindSmallReranking
833
+ config: default
834
+ split: test
835
+ revision: 3bdac13927fdc888b903db93b2ffdbd90b295a69
836
+ metrics:
837
+ - type: map
838
+ value: 32.01814326295803
839
+ - type: mrr
840
+ value: 33.20555240055367
841
+ - task:
842
+ type: Retrieval
843
+ dataset:
844
+ type: nfcorpus
845
+ name: MTEB NFCorpus
846
+ config: default
847
+ split: test
848
+ revision: None
849
+ metrics:
850
+ - type: map_at_1
851
+ value: 3.3910000000000005
852
+ - type: map_at_10
853
+ value: 7.7219999999999995
854
+ - type: map_at_100
855
+ value: 10.286
856
+ - type: map_at_1000
857
+ value: 11.668000000000001
858
+ - type: map_at_3
859
+ value: 5.552
860
+ - type: map_at_5
861
+ value: 6.468
862
+ - type: mrr_at_1
863
+ value: 34.365
864
+ - type: mrr_at_10
865
+ value: 42.555
866
+ - type: mrr_at_100
867
+ value: 43.295
868
+ - type: mrr_at_1000
869
+ value: 43.357
870
+ - type: mrr_at_3
871
+ value: 40.299
872
+ - type: mrr_at_5
873
+ value: 41.182
874
+ - type: ndcg_at_1
875
+ value: 31.424000000000003
876
+ - type: ndcg_at_10
877
+ value: 24.758
878
+ - type: ndcg_at_100
879
+ value: 23.677999999999997
880
+ - type: ndcg_at_1000
881
+ value: 33.377
882
+ - type: ndcg_at_3
883
+ value: 28.302
884
+ - type: ndcg_at_5
885
+ value: 26.342
886
+ - type: precision_at_1
887
+ value: 33.437
888
+ - type: precision_at_10
889
+ value: 19.256999999999998
890
+ - type: precision_at_100
891
+ value: 6.662999999999999
892
+ - type: precision_at_1000
893
+ value: 1.9900000000000002
894
+ - type: precision_at_3
895
+ value: 27.761000000000003
896
+ - type: precision_at_5
897
+ value: 23.715
898
+ - type: recall_at_1
899
+ value: 3.3910000000000005
900
+ - type: recall_at_10
901
+ value: 11.068
902
+ - type: recall_at_100
903
+ value: 25.878
904
+ - type: recall_at_1000
905
+ value: 60.19
906
+ - type: recall_at_3
907
+ value: 6.1690000000000005
908
+ - type: recall_at_5
909
+ value: 7.767
910
+ - task:
911
+ type: Retrieval
912
+ dataset:
913
+ type: nq
914
+ name: MTEB NQ
915
+ config: default
916
+ split: test
917
+ revision: None
918
+ metrics:
919
+ - type: map_at_1
920
+ value: 15.168000000000001
921
+ - type: map_at_10
922
+ value: 26.177
923
+ - type: map_at_100
924
+ value: 27.564
925
+ - type: map_at_1000
926
+ value: 27.628999999999998
927
+ - type: map_at_3
928
+ value: 22.03
929
+ - type: map_at_5
930
+ value: 24.276
931
+ - type: mrr_at_1
932
+ value: 17.439
933
+ - type: mrr_at_10
934
+ value: 28.205000000000002
935
+ - type: mrr_at_100
936
+ value: 29.357
937
+ - type: mrr_at_1000
938
+ value: 29.408
939
+ - type: mrr_at_3
940
+ value: 24.377
941
+ - type: mrr_at_5
942
+ value: 26.540000000000003
943
+ - type: ndcg_at_1
944
+ value: 17.41
945
+ - type: ndcg_at_10
946
+ value: 32.936
947
+ - type: ndcg_at_100
948
+ value: 39.196999999999996
949
+ - type: ndcg_at_1000
950
+ value: 40.892
951
+ - type: ndcg_at_3
952
+ value: 24.721
953
+ - type: ndcg_at_5
954
+ value: 28.615000000000002
955
+ - type: precision_at_1
956
+ value: 17.41
957
+ - type: precision_at_10
958
+ value: 6.199000000000001
959
+ - type: precision_at_100
960
+ value: 0.9690000000000001
961
+ - type: precision_at_1000
962
+ value: 0.11299999999999999
963
+ - type: precision_at_3
964
+ value: 11.790000000000001
965
+ - type: precision_at_5
966
+ value: 9.264
967
+ - type: recall_at_1
968
+ value: 15.168000000000001
969
+ - type: recall_at_10
970
+ value: 51.914
971
+ - type: recall_at_100
972
+ value: 79.804
973
+ - type: recall_at_1000
974
+ value: 92.75999999999999
975
+ - type: recall_at_3
976
+ value: 30.212
977
+ - type: recall_at_5
978
+ value: 39.204
979
+ - task:
980
+ type: Retrieval
981
+ dataset:
982
+ type: quora
983
+ name: MTEB QuoraRetrieval
984
+ config: default
985
+ split: test
986
+ revision: None
987
+ metrics:
988
+ - type: map_at_1
989
+ value: 67.306
990
+ - type: map_at_10
991
+ value: 80.634
992
+ - type: map_at_100
993
+ value: 81.349
994
+ - type: map_at_1000
995
+ value: 81.37299999999999
996
+ - type: map_at_3
997
+ value: 77.691
998
+ - type: map_at_5
999
+ value: 79.512
1000
+ - type: mrr_at_1
1001
+ value: 77.56
1002
+ - type: mrr_at_10
1003
+ value: 84.177
1004
+ - type: mrr_at_100
1005
+ value: 84.35000000000001
1006
+ - type: mrr_at_1000
1007
+ value: 84.353
1008
+ - type: mrr_at_3
1009
+ value: 83.003
1010
+ - type: mrr_at_5
1011
+ value: 83.799
1012
+ - type: ndcg_at_1
1013
+ value: 77.58
1014
+ - type: ndcg_at_10
1015
+ value: 84.782
1016
+ - type: ndcg_at_100
1017
+ value: 86.443
1018
+ - type: ndcg_at_1000
1019
+ value: 86.654
1020
+ - type: ndcg_at_3
1021
+ value: 81.67
1022
+ - type: ndcg_at_5
1023
+ value: 83.356
1024
+ - type: precision_at_1
1025
+ value: 77.58
1026
+ - type: precision_at_10
1027
+ value: 12.875
1028
+ - type: precision_at_100
1029
+ value: 1.503
1030
+ - type: precision_at_1000
1031
+ value: 0.156
1032
+ - type: precision_at_3
1033
+ value: 35.63
1034
+ - type: precision_at_5
1035
+ value: 23.483999999999998
1036
+ - type: recall_at_1
1037
+ value: 67.306
1038
+ - type: recall_at_10
1039
+ value: 92.64
1040
+ - type: recall_at_100
1041
+ value: 98.681
1042
+ - type: recall_at_1000
1043
+ value: 99.79
1044
+ - type: recall_at_3
1045
+ value: 83.682
1046
+ - type: recall_at_5
1047
+ value: 88.424
1048
+ - task:
1049
+ type: Clustering
1050
+ dataset:
1051
+ type: mteb/reddit-clustering
1052
+ name: MTEB RedditClustering
1053
+ config: default
1054
+ split: test
1055
+ revision: 24640382cdbf8abc73003fb0fa6d111a705499eb
1056
+ metrics:
1057
+ - type: v_measure
1058
+ value: 50.76319866126382
1059
+ - task:
1060
+ type: Clustering
1061
+ dataset:
1062
+ type: mteb/reddit-clustering-p2p
1063
+ name: MTEB RedditClusteringP2P
1064
+ config: default
1065
+ split: test
1066
+ revision: 282350215ef01743dc01b456c7f5241fa8937f16
1067
+ metrics:
1068
+ - type: v_measure
1069
+ value: 55.024711941648995
1070
+ - task:
1071
+ type: Retrieval
1072
+ dataset:
1073
+ type: scidocs
1074
+ name: MTEB SCIDOCS
1075
+ config: default
1076
+ split: test
1077
+ revision: None
1078
+ metrics:
1079
+ - type: map_at_1
1080
+ value: 3.9379999999999997
1081
+ - type: map_at_10
1082
+ value: 8.817
1083
+ - type: map_at_100
1084
+ value: 10.546999999999999
1085
+ - type: map_at_1000
1086
+ value: 10.852
1087
+ - type: map_at_3
1088
+ value: 6.351999999999999
1089
+ - type: map_at_5
1090
+ value: 7.453
1091
+ - type: mrr_at_1
1092
+ value: 19.400000000000002
1093
+ - type: mrr_at_10
1094
+ value: 27.371000000000002
1095
+ - type: mrr_at_100
1096
+ value: 28.671999999999997
1097
+ - type: mrr_at_1000
1098
+ value: 28.747
1099
+ - type: mrr_at_3
1100
+ value: 24.583
1101
+ - type: mrr_at_5
1102
+ value: 26.143
1103
+ - type: ndcg_at_1
1104
+ value: 19.400000000000002
1105
+ - type: ndcg_at_10
1106
+ value: 15.264
1107
+ - type: ndcg_at_100
1108
+ value: 22.63
1109
+ - type: ndcg_at_1000
1110
+ value: 28.559
1111
+ - type: ndcg_at_3
1112
+ value: 14.424999999999999
1113
+ - type: ndcg_at_5
1114
+ value: 12.520000000000001
1115
+ - type: precision_at_1
1116
+ value: 19.400000000000002
1117
+ - type: precision_at_10
1118
+ value: 7.8100000000000005
1119
+ - type: precision_at_100
1120
+ value: 1.854
1121
+ - type: precision_at_1000
1122
+ value: 0.329
1123
+ - type: precision_at_3
1124
+ value: 13.100000000000001
1125
+ - type: precision_at_5
1126
+ value: 10.68
1127
+ - type: recall_at_1
1128
+ value: 3.9379999999999997
1129
+ - type: recall_at_10
1130
+ value: 15.903
1131
+ - type: recall_at_100
1132
+ value: 37.645
1133
+ - type: recall_at_1000
1134
+ value: 66.86
1135
+ - type: recall_at_3
1136
+ value: 7.993
1137
+ - type: recall_at_5
1138
+ value: 10.885
1139
+ - task:
1140
+ type: STS
1141
+ dataset:
1142
+ type: mteb/sickr-sts
1143
+ name: MTEB SICK-R
1144
+ config: default
1145
+ split: test
1146
+ revision: a6ea5a8cab320b040a23452cc28066d9beae2cee
1147
+ metrics:
1148
+ - type: cos_sim_pearson
1149
+ value: 80.12689060151425
1150
+ - type: cos_sim_spearman
1151
+ value: 70.46515535094771
1152
+ - type: euclidean_pearson
1153
+ value: 77.17160003557223
1154
+ - type: euclidean_spearman
1155
+ value: 70.4651757047438
1156
+ - type: manhattan_pearson
1157
+ value: 77.18129609281937
1158
+ - type: manhattan_spearman
1159
+ value: 70.46610403752913
1160
+ - task:
1161
+ type: STS
1162
+ dataset:
1163
+ type: mteb/sts12-sts
1164
+ name: MTEB STS12
1165
+ config: default
1166
+ split: test
1167
+ revision: a0d554a64d88156834ff5ae9920b964011b16384
1168
+ metrics:
1169
+ - type: cos_sim_pearson
1170
+ value: 70.451157033355
1171
+ - type: cos_sim_spearman
1172
+ value: 63.99899601697852
1173
+ - type: euclidean_pearson
1174
+ value: 67.46985359967678
1175
+ - type: euclidean_spearman
1176
+ value: 64.00001637764805
1177
+ - type: manhattan_pearson
1178
+ value: 67.56534741780037
1179
+ - type: manhattan_spearman
1180
+ value: 64.06533893575366
1181
+ - task:
1182
+ type: STS
1183
+ dataset:
1184
+ type: mteb/sts13-sts
1185
+ name: MTEB STS13
1186
+ config: default
1187
+ split: test
1188
+ revision: 7e90230a92c190f1bf69ae9002b8cea547a64cca
1189
+ metrics:
1190
+ - type: cos_sim_pearson
1191
+ value: 77.65086614464292
1192
+ - type: cos_sim_spearman
1193
+ value: 78.20169706921848
1194
+ - type: euclidean_pearson
1195
+ value: 77.77758172155283
1196
+ - type: euclidean_spearman
1197
+ value: 78.20169706921848
1198
+ - type: manhattan_pearson
1199
+ value: 77.75077884860052
1200
+ - type: manhattan_spearman
1201
+ value: 78.16875216484164
1202
+ - task:
1203
+ type: STS
1204
+ dataset:
1205
+ type: mteb/sts14-sts
1206
+ name: MTEB STS14
1207
+ config: default
1208
+ split: test
1209
+ revision: 6031580fec1f6af667f0bd2da0a551cf4f0b2375
1210
+ metrics:
1211
+ - type: cos_sim_pearson
1212
+ value: 76.26381598259717
1213
+ - type: cos_sim_spearman
1214
+ value: 70.78377709313477
1215
+ - type: euclidean_pearson
1216
+ value: 74.82646556532096
1217
+ - type: euclidean_spearman
1218
+ value: 70.78377658155212
1219
+ - type: manhattan_pearson
1220
+ value: 74.81784766108225
1221
+ - type: manhattan_spearman
1222
+ value: 70.79351454692176
1223
+ - task:
1224
+ type: STS
1225
+ dataset:
1226
+ type: mteb/sts15-sts
1227
+ name: MTEB STS15
1228
+ config: default
1229
+ split: test
1230
+ revision: ae752c7c21bf194d8b67fd573edf7ae58183cbe3
1231
+ metrics:
1232
+ - type: cos_sim_pearson
1233
+ value: 79.00532026789739
1234
+ - type: cos_sim_spearman
1235
+ value: 80.02708383244838
1236
+ - type: euclidean_pearson
1237
+ value: 79.48345422610525
1238
+ - type: euclidean_spearman
1239
+ value: 80.02708383244838
1240
+ - type: manhattan_pearson
1241
+ value: 79.44519739854803
1242
+ - type: manhattan_spearman
1243
+ value: 79.98344094559687
1244
+ - task:
1245
+ type: STS
1246
+ dataset:
1247
+ type: mteb/sts16-sts
1248
+ name: MTEB STS16
1249
+ config: default
1250
+ split: test
1251
+ revision: 4d8694f8f0e0100860b497b999b3dbed754a0513
1252
+ metrics:
1253
+ - type: cos_sim_pearson
1254
+ value: 77.32783048164805
1255
+ - type: cos_sim_spearman
1256
+ value: 78.79729961288045
1257
+ - type: euclidean_pearson
1258
+ value: 78.72111945793154
1259
+ - type: euclidean_spearman
1260
+ value: 78.79729904606872
1261
+ - type: manhattan_pearson
1262
+ value: 78.72464311117116
1263
+ - type: manhattan_spearman
1264
+ value: 78.822591248334
1265
+ - task:
1266
+ type: STS
1267
+ dataset:
1268
+ type: mteb/sts17-crosslingual-sts
1269
+ name: MTEB STS17 (en-en)
1270
+ config: en-en
1271
+ split: test
1272
+ revision: af5e6fb845001ecf41f4c1e033ce921939a2a68d
1273
+ metrics:
1274
+ - type: cos_sim_pearson
1275
+ value: 82.04318630630854
1276
+ - type: cos_sim_spearman
1277
+ value: 83.87886389259836
1278
+ - type: euclidean_pearson
1279
+ value: 83.40385877895086
1280
+ - type: euclidean_spearman
1281
+ value: 83.87886389259836
1282
+ - type: manhattan_pearson
1283
+ value: 83.46337128901547
1284
+ - type: manhattan_spearman
1285
+ value: 83.9723106941644
1286
+ - task:
1287
+ type: STS
1288
+ dataset:
1289
+ type: mteb/sts22-crosslingual-sts
1290
+ name: MTEB STS22 (en)
1291
+ config: en
1292
+ split: test
1293
+ revision: 6d1ba47164174a496b7fa5d3569dae26a6813b80
1294
+ metrics:
1295
+ - type: cos_sim_pearson
1296
+ value: 63.003511169944595
1297
+ - type: cos_sim_spearman
1298
+ value: 64.39318805580227
1299
+ - type: euclidean_pearson
1300
+ value: 65.4797990735967
1301
+ - type: euclidean_spearman
1302
+ value: 64.39318805580227
1303
+ - type: manhattan_pearson
1304
+ value: 65.44604544280844
1305
+ - type: manhattan_spearman
1306
+ value: 64.38742899984233
1307
+ - task:
1308
+ type: STS
1309
+ dataset:
1310
+ type: mteb/stsbenchmark-sts
1311
+ name: MTEB STSBenchmark
1312
+ config: default
1313
+ split: test
1314
+ revision: b0fddb56ed78048fa8b90373c8a3cfc37b684831
1315
+ metrics:
1316
+ - type: cos_sim_pearson
1317
+ value: 76.63101237585029
1318
+ - type: cos_sim_spearman
1319
+ value: 75.57446967644269
1320
+ - type: euclidean_pearson
1321
+ value: 76.93491768734478
1322
+ - type: euclidean_spearman
1323
+ value: 75.57446967644269
1324
+ - type: manhattan_pearson
1325
+ value: 76.92187567800636
1326
+ - type: manhattan_spearman
1327
+ value: 75.57239337194585
1328
+ - task:
1329
+ type: Reranking
1330
+ dataset:
1331
+ type: mteb/scidocs-reranking
1332
+ name: MTEB SciDocsRR
1333
+ config: default
1334
+ split: test
1335
+ revision: d3c5e1fc0b855ab6097bf1cda04dd73947d7caab
1336
+ metrics:
1337
+ - type: map
1338
+ value: 78.5376604868993
1339
+ - type: mrr
1340
+ value: 92.94422897364073
1341
+ - task:
1342
+ type: Retrieval
1343
+ dataset:
1344
+ type: scifact
1345
+ name: MTEB SciFact
1346
+ config: default
1347
+ split: test
1348
+ revision: None
1349
+ metrics:
1350
+ - type: map_at_1
1351
+ value: 38.872
1352
+ - type: map_at_10
1353
+ value: 50.417
1354
+ - type: map_at_100
1355
+ value: 51.202000000000005
1356
+ - type: map_at_1000
1357
+ value: 51.25999999999999
1358
+ - type: map_at_3
1359
+ value: 47.02
1360
+ - type: map_at_5
1361
+ value: 49.326
1362
+ - type: mrr_at_1
1363
+ value: 41.0
1364
+ - type: mrr_at_10
1365
+ value: 51.674
1366
+ - type: mrr_at_100
1367
+ value: 52.32599999999999
1368
+ - type: mrr_at_1000
1369
+ value: 52.376999999999995
1370
+ - type: mrr_at_3
1371
+ value: 48.778
1372
+ - type: mrr_at_5
1373
+ value: 50.744
1374
+ - type: ndcg_at_1
1375
+ value: 41.0
1376
+ - type: ndcg_at_10
1377
+ value: 56.027
1378
+ - type: ndcg_at_100
1379
+ value: 59.362
1380
+ - type: ndcg_at_1000
1381
+ value: 60.839
1382
+ - type: ndcg_at_3
1383
+ value: 50.019999999999996
1384
+ - type: ndcg_at_5
1385
+ value: 53.644999999999996
1386
+ - type: precision_at_1
1387
+ value: 41.0
1388
+ - type: precision_at_10
1389
+ value: 8.1
1390
+ - type: precision_at_100
1391
+ value: 0.987
1392
+ - type: precision_at_1000
1393
+ value: 0.11100000000000002
1394
+ - type: precision_at_3
1395
+ value: 20.444000000000003
1396
+ - type: precision_at_5
1397
+ value: 14.466999999999999
1398
+ - type: recall_at_1
1399
+ value: 38.872
1400
+ - type: recall_at_10
1401
+ value: 71.906
1402
+ - type: recall_at_100
1403
+ value: 86.367
1404
+ - type: recall_at_1000
1405
+ value: 98.0
1406
+ - type: recall_at_3
1407
+ value: 56.206
1408
+ - type: recall_at_5
1409
+ value: 65.05
1410
+ - task:
1411
+ type: PairClassification
1412
+ dataset:
1413
+ type: mteb/sprintduplicatequestions-pairclassification
1414
+ name: MTEB SprintDuplicateQuestions
1415
+ config: default
1416
+ split: test
1417
+ revision: d66bd1f72af766a5cc4b0ca5e00c162f89e8cc46
1418
+ metrics:
1419
+ - type: cos_sim_accuracy
1420
+ value: 99.7039603960396
1421
+ - type: cos_sim_ap
1422
+ value: 90.40809844250262
1423
+ - type: cos_sim_f1
1424
+ value: 84.53181583031557
1425
+ - type: cos_sim_precision
1426
+ value: 87.56698821007502
1427
+ - type: cos_sim_recall
1428
+ value: 81.69999999999999
1429
+ - type: dot_accuracy
1430
+ value: 99.7039603960396
1431
+ - type: dot_ap
1432
+ value: 90.40809844250262
1433
+ - type: dot_f1
1434
+ value: 84.53181583031557
1435
+ - type: dot_precision
1436
+ value: 87.56698821007502
1437
+ - type: dot_recall
1438
+ value: 81.69999999999999
1439
+ - type: euclidean_accuracy
1440
+ value: 99.7039603960396
1441
+ - type: euclidean_ap
1442
+ value: 90.4080982863383
1443
+ - type: euclidean_f1
1444
+ value: 84.53181583031557
1445
+ - type: euclidean_precision
1446
+ value: 87.56698821007502
1447
+ - type: euclidean_recall
1448
+ value: 81.69999999999999
1449
+ - type: manhattan_accuracy
1450
+ value: 99.7
1451
+ - type: manhattan_ap
1452
+ value: 90.39771161966652
1453
+ - type: manhattan_f1
1454
+ value: 84.32989690721648
1455
+ - type: manhattan_precision
1456
+ value: 87.02127659574468
1457
+ - type: manhattan_recall
1458
+ value: 81.8
1459
+ - type: max_accuracy
1460
+ value: 99.7039603960396
1461
+ - type: max_ap
1462
+ value: 90.40809844250262
1463
+ - type: max_f1
1464
+ value: 84.53181583031557
1465
+ - task:
1466
+ type: Clustering
1467
+ dataset:
1468
+ type: mteb/stackexchange-clustering
1469
+ name: MTEB StackExchangeClustering
1470
+ config: default
1471
+ split: test
1472
+ revision: 6cbc1f7b2bc0622f2e39d2c77fa502909748c259
1473
+ metrics:
1474
+ - type: v_measure
1475
+ value: 59.663210666678715
1476
+ - task:
1477
+ type: Clustering
1478
+ dataset:
1479
+ type: mteb/stackexchange-clustering-p2p
1480
+ name: MTEB StackExchangeClusteringP2P
1481
+ config: default
1482
+ split: test
1483
+ revision: 815ca46b2622cec33ccafc3735d572c266efdb44
1484
+ metrics:
1485
+ - type: v_measure
1486
+ value: 32.107791216468776
1487
+ - task:
1488
+ type: Reranking
1489
+ dataset:
1490
+ type: mteb/stackoverflowdupquestions-reranking
1491
+ name: MTEB StackOverflowDupQuestions
1492
+ config: default
1493
+ split: test
1494
+ revision: e185fbe320c72810689fc5848eb6114e1ef5ec69
1495
+ metrics:
1496
+ - type: map
1497
+ value: 46.440691925067604
1498
+ - type: mrr
1499
+ value: 47.03390257618199
1500
+ - task:
1501
+ type: Summarization
1502
+ dataset:
1503
+ type: mteb/summeval
1504
+ name: MTEB SummEval
1505
+ config: default
1506
+ split: test
1507
+ revision: cda12ad7615edc362dbf25a00fdd61d3b1eaf93c
1508
+ metrics:
1509
+ - type: cos_sim_pearson
1510
+ value: 31.067177519784074
1511
+ - type: cos_sim_spearman
1512
+ value: 31.234728424648967
1513
+ - type: dot_pearson
1514
+ value: 31.06717083018107
1515
+ - type: dot_spearman
1516
+ value: 31.234728424648967
1517
+ - task:
1518
+ type: Retrieval
1519
+ dataset:
1520
+ type: trec-covid
1521
+ name: MTEB TRECCOVID
1522
+ config: default
1523
+ split: test
1524
+ revision: None
1525
+ metrics:
1526
+ - type: map_at_1
1527
+ value: 0.136
1528
+ - type: map_at_10
1529
+ value: 0.767
1530
+ - type: map_at_100
1531
+ value: 3.3689999999999998
1532
+ - type: map_at_1000
1533
+ value: 8.613999999999999
1534
+ - type: map_at_3
1535
+ value: 0.369
1536
+ - type: map_at_5
1537
+ value: 0.514
1538
+ - type: mrr_at_1
1539
+ value: 48.0
1540
+ - type: mrr_at_10
1541
+ value: 63.908
1542
+ - type: mrr_at_100
1543
+ value: 64.615
1544
+ - type: mrr_at_1000
1545
+ value: 64.615
1546
+ - type: mrr_at_3
1547
+ value: 62.0
1548
+ - type: mrr_at_5
1549
+ value: 63.4
1550
+ - type: ndcg_at_1
1551
+ value: 44.0
1552
+ - type: ndcg_at_10
1553
+ value: 38.579
1554
+ - type: ndcg_at_100
1555
+ value: 26.409
1556
+ - type: ndcg_at_1000
1557
+ value: 26.858999999999998
1558
+ - type: ndcg_at_3
1559
+ value: 47.134
1560
+ - type: ndcg_at_5
1561
+ value: 43.287
1562
+ - type: precision_at_1
1563
+ value: 48.0
1564
+ - type: precision_at_10
1565
+ value: 40.400000000000006
1566
+ - type: precision_at_100
1567
+ value: 26.640000000000004
1568
+ - type: precision_at_1000
1569
+ value: 12.04
1570
+ - type: precision_at_3
1571
+ value: 52.666999999999994
1572
+ - type: precision_at_5
1573
+ value: 46.800000000000004
1574
+ - type: recall_at_1
1575
+ value: 0.136
1576
+ - type: recall_at_10
1577
+ value: 1.0070000000000001
1578
+ - type: recall_at_100
1579
+ value: 6.318
1580
+ - type: recall_at_1000
1581
+ value: 26.522000000000002
1582
+ - type: recall_at_3
1583
+ value: 0.41700000000000004
1584
+ - type: recall_at_5
1585
+ value: 0.606
1586
+ - task:
1587
+ type: Retrieval
1588
+ dataset:
1589
+ type: webis-touche2020
1590
+ name: MTEB Touche2020
1591
+ config: default
1592
+ split: test
1593
+ revision: None
1594
+ metrics:
1595
+ - type: map_at_1
1596
+ value: 1.9949999999999999
1597
+ - type: map_at_10
1598
+ value: 8.304
1599
+ - type: map_at_100
1600
+ value: 13.644
1601
+ - type: map_at_1000
1602
+ value: 15.43
1603
+ - type: map_at_3
1604
+ value: 4.788
1605
+ - type: map_at_5
1606
+ value: 6.22
1607
+ - type: mrr_at_1
1608
+ value: 22.448999999999998
1609
+ - type: mrr_at_10
1610
+ value: 37.658
1611
+ - type: mrr_at_100
1612
+ value: 38.491
1613
+ - type: mrr_at_1000
1614
+ value: 38.503
1615
+ - type: mrr_at_3
1616
+ value: 32.312999999999995
1617
+ - type: mrr_at_5
1618
+ value: 35.68
1619
+ - type: ndcg_at_1
1620
+ value: 21.429000000000002
1621
+ - type: ndcg_at_10
1622
+ value: 18.995
1623
+ - type: ndcg_at_100
1624
+ value: 32.029999999999994
1625
+ - type: ndcg_at_1000
1626
+ value: 44.852
1627
+ - type: ndcg_at_3
1628
+ value: 19.464000000000002
1629
+ - type: ndcg_at_5
1630
+ value: 19.172
1631
+ - type: precision_at_1
1632
+ value: 22.448999999999998
1633
+ - type: precision_at_10
1634
+ value: 17.143
1635
+ - type: precision_at_100
1636
+ value: 6.877999999999999
1637
+ - type: precision_at_1000
1638
+ value: 1.524
1639
+ - type: precision_at_3
1640
+ value: 21.769
1641
+ - type: precision_at_5
1642
+ value: 20.0
1643
+ - type: recall_at_1
1644
+ value: 1.9949999999999999
1645
+ - type: recall_at_10
1646
+ value: 13.395999999999999
1647
+ - type: recall_at_100
1648
+ value: 44.348
1649
+ - type: recall_at_1000
1650
+ value: 82.622
1651
+ - type: recall_at_3
1652
+ value: 5.896
1653
+ - type: recall_at_5
1654
+ value: 8.554
1655
+ - task:
1656
+ type: Classification
1657
+ dataset:
1658
+ type: mteb/toxic_conversations_50k
1659
+ name: MTEB ToxicConversationsClassification
1660
+ config: default
1661
+ split: test
1662
+ revision: d7c0de2777da35d6aae2200a62c6e0e5af397c4c
1663
+ metrics:
1664
+ - type: accuracy
1665
+ value: 67.9394
1666
+ - type: ap
1667
+ value: 12.943337263423334
1668
+ - type: f1
1669
+ value: 52.28243093094156
1670
+ - task:
1671
+ type: Classification
1672
+ dataset:
1673
+ type: mteb/tweet_sentiment_extraction
1674
+ name: MTEB TweetSentimentExtractionClassification
1675
+ config: default
1676
+ split: test
1677
+ revision: d604517c81ca91fe16a244d1248fc021f9ecee7a
1678
+ metrics:
1679
+ - type: accuracy
1680
+ value: 56.414827391058296
1681
+ - type: f1
1682
+ value: 56.666412409573105
1683
+ - task:
1684
+ type: Clustering
1685
+ dataset:
1686
+ type: mteb/twentynewsgroups-clustering
1687
+ name: MTEB TwentyNewsgroupsClustering
1688
+ config: default
1689
+ split: test
1690
+ revision: 6125ec4e24fa026cec8a478383ee943acfbd5449
1691
+ metrics:
1692
+ - type: v_measure
1693
+ value: 47.009746255495465
1694
+ - task:
1695
+ type: PairClassification
1696
+ dataset:
1697
+ type: mteb/twittersemeval2015-pairclassification
1698
+ name: MTEB TwitterSemEval2015
1699
+ config: default
1700
+ split: test
1701
+ revision: 70970daeab8776df92f5ea462b6173c0b46fd2d1
1702
+ metrics:
1703
+ - type: cos_sim_accuracy
1704
+ value: 84.02574953805807
1705
+ - type: cos_sim_ap
1706
+ value: 67.66599910763128
1707
+ - type: cos_sim_f1
1708
+ value: 63.491277990844985
1709
+ - type: cos_sim_precision
1710
+ value: 59.77172140694154
1711
+ - type: cos_sim_recall
1712
+ value: 67.70448548812665
1713
+ - type: dot_accuracy
1714
+ value: 84.02574953805807
1715
+ - type: dot_ap
1716
+ value: 67.66600090945406
1717
+ - type: dot_f1
1718
+ value: 63.491277990844985
1719
+ - type: dot_precision
1720
+ value: 59.77172140694154
1721
+ - type: dot_recall
1722
+ value: 67.70448548812665
1723
+ - type: euclidean_accuracy
1724
+ value: 84.02574953805807
1725
+ - type: euclidean_ap
1726
+ value: 67.6659842364448
1727
+ - type: euclidean_f1
1728
+ value: 63.491277990844985
1729
+ - type: euclidean_precision
1730
+ value: 59.77172140694154
1731
+ - type: euclidean_recall
1732
+ value: 67.70448548812665
1733
+ - type: manhattan_accuracy
1734
+ value: 84.0317100792752
1735
+ - type: manhattan_ap
1736
+ value: 67.66351692448987
1737
+ - type: manhattan_f1
1738
+ value: 63.48610948306178
1739
+ - type: manhattan_precision
1740
+ value: 57.11875131828729
1741
+ - type: manhattan_recall
1742
+ value: 71.45118733509234
1743
+ - type: max_accuracy
1744
+ value: 84.0317100792752
1745
+ - type: max_ap
1746
+ value: 67.66600090945406
1747
+ - type: max_f1
1748
+ value: 63.491277990844985
1749
+ - task:
1750
+ type: PairClassification
1751
+ dataset:
1752
+ type: mteb/twitterurlcorpus-pairclassification
1753
+ name: MTEB TwitterURLCorpus
1754
+ config: default
1755
+ split: test
1756
+ revision: 8b6510b0b1fa4e4c4f879467980e9be563ec1cdf
1757
+ metrics:
1758
+ - type: cos_sim_accuracy
1759
+ value: 87.53832421314084
1760
+ - type: cos_sim_ap
1761
+ value: 83.11416594316626
1762
+ - type: cos_sim_f1
1763
+ value: 75.41118114347518
1764
+ - type: cos_sim_precision
1765
+ value: 73.12839059674504
1766
+ - type: cos_sim_recall
1767
+ value: 77.8410840776101
1768
+ - type: dot_accuracy
1769
+ value: 87.53832421314084
1770
+ - type: dot_ap
1771
+ value: 83.11416226342155
1772
+ - type: dot_f1
1773
+ value: 75.41118114347518
1774
+ - type: dot_precision
1775
+ value: 73.12839059674504
1776
+ - type: dot_recall
1777
+ value: 77.8410840776101
1778
+ - type: euclidean_accuracy
1779
+ value: 87.53832421314084
1780
+ - type: euclidean_ap
1781
+ value: 83.11416284455395
1782
+ - type: euclidean_f1
1783
+ value: 75.41118114347518
1784
+ - type: euclidean_precision
1785
+ value: 73.12839059674504
1786
+ - type: euclidean_recall
1787
+ value: 77.8410840776101
1788
+ - type: manhattan_accuracy
1789
+ value: 87.49369348391353
1790
+ - type: manhattan_ap
1791
+ value: 83.08066812574694
1792
+ - type: manhattan_f1
1793
+ value: 75.36561228603892
1794
+ - type: manhattan_precision
1795
+ value: 71.9202518363064
1796
+ - type: manhattan_recall
1797
+ value: 79.15768401601478
1798
+ - type: max_accuracy
1799
+ value: 87.53832421314084
1800
+ - type: max_ap
1801
+ value: 83.11416594316626
1802
+ - type: max_f1
1803
+ value: 75.41118114347518
1804
  ---
1805
+
1806
+ # lodestone-base-4096-v1
1807
+
1808
+ This new [sentence-transformers](https://www.SBERT.net) model from [Hum](https://www.hum.works/) maps long sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.
1809
+
1810
+ ## Abstract
1811
+
1812
+ In the hopes of furthering Hum's overarching mission of increasing the accessibility and interconnectivity of human knowledge, this model was developed as part of a project intending to boost the maximum input sequence length of sentence embedding models by leveraging recent architectural advances in the design of transformer models such as the incorporation of FlashAttention, Attention with Linear Biases (ALiBi), and Gated Linear Units (GLU). These modifications and enhancements were implemented by the team at MosaicML who designed and constructed the pre-trained [`mosaic-bert-base-seqlen-2048`](https://huggingface.co/mosaicml/mosaic-bert-base-seqlen-2048) model, and more information regarding the details of their development and testing specifications can be found on the model card.
1813
+
1814
+ While the fine-tuning procedure followed during the course of this project loosely mirrors that of the of the original [Flax-sentence-embeddings](https://huggingface.co/flax-sentence-embeddings) team responsible for the creation of many other popular sentence-transformers models (e.g. [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2), [all-distilroberta-v1](https://huggingface.co/sentence-transformers/all-distilroberta-v1), and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)), our methodology includes novel techniques for data loading, batch sampling, and model checkpointing intended to improve training efficiency with regards to memory allocation and data storage.
1815
+
1816
+ Through combining these well-established and proven fine-tuning practices with novel advances in transformer architectural elements, our `lodestone-base-4096-v1` model is able to achieve comparable performance metrics on standard text embedding evaluation benchmarks while also supporting a longer and more robust input sequence length of 4096 while retaining a smaller, more manageable size capable of being run on either a GPU or CPU.
1817
+
1818
+ ## Usage
1819
+
1820
+ Using this model becomes relatively easy when you have [sentence-transformers](https://www.SBERT.net) installed.
1821
+ *At the time of publishing, sentence-transformers does not support remote code which is required for flash-attention used by the model. A fork of the sentence-transformers repository that allows remote code execution is provided for convenience. It can be installed using the following command:*
1822
+ ```
1823
+ pip install git+https://github.com/Hum-Works/sentence-transformers.git
1824
+ ```
1825
+
1826
+ Then you can use the model like this:
1827
+ ```python
1828
+ from sentence_transformers import SentenceTransformer
1829
+
1830
+ model = SentenceTransformer('lodestone-base-4096-v1', trust_remote_code=True, revision='v1.0.0')
1831
+ sentences = ["This is an example sentence", "Each sentence is converted"]
1832
+ embeddings = model.encode(sentences)
1833
+ print(embeddings)
1834
+ ```
1835
+ *Note: The model will use the openAI/Triton implementation of FlashAttention if installed. This is more performant than the fallback, torch implementation. Some platforms and GPUs may not be supported by Triton - up to date compatibility can be found on [Triton’s github page](https://github.com/openai/triton#compatibility).*
1836
+
1837
+ ------
1838
+
1839
+ ## Background
1840
+
1841
+ The project aims to train sentence embedding models on very large sentence level datasets using a self-supervised contrastive learning objective. We used the pretrained [`mosaic-bert-base-seqlen-2048`](https://huggingface.co/mosaicml/mosaic-bert-base-seqlen-2048) model and fine-tuned it on a nearly 1.5B sentence pairs dataset. We use a contrastive learning objective: given a sentence from the pair, the model should predict which out of a set of randomly sampled other sentences, was actually paired with it in our dataset.
1842
+
1843
+ ## Intended uses
1844
+
1845
+ Our model is intended to be used as a long sentence and paragraph encoder. Given an input text, it outputs a vector containing the semantic information. The sentence vector may be used for information retrieval, clustering, or sentence similarity tasks.
1846
+
1847
+ ## Training procedure
1848
+
1849
+ ### Pre-training
1850
+
1851
+ We use the pretrained [`mosaic-bert-base-seqlen-2048`](https://huggingface.co/mosaicml/mosaic-bert-base-seqlen-2048). Please refer to the model card for more detailed information about the pre-training procedure.
1852
+
1853
+ ### Fine-tuning
1854
+
1855
+ We fine-tune the model using a contrastive objective. Formally, we compute the dot product of each possible sentence pairing in the batch. We then apply the cross entropy loss by comparing with true pairs.
1856
+
1857
+ #### Hyperparameters
1858
+
1859
+ We trained our model on an ml.g5.4xlarge EC2 instance with 1 NVIDIA A10G Tensor Core GPU. We train the model during 1.4 million steps using a batch size of 16. We use a learning rate warm up of 500. The sequence length during training was limited to 2048 tokens. We used the AdamW optimizer with a 2e-5 learning rate and weight decay of 0.01 (i.e. the default parameter values for SentenceTransformer.fit()). The full training script is accessible in this current repository: `Training.py`.
1860
+
1861
+ ## Model Architecture
1862
+ By incorporating FlashAttention, [Attention with Linear Biases (ALiBi)](https://arxiv.org/abs/2108.12409), and Gated Linear Units (GLU), this model is able to handle input sequences of 4096, 8x longer than that supported by most comparable sentence embedding models.
1863
+ The model was trained using a sequence length maximum of 2048, but the final model has a maximum sequence length of 4096. This is accomplished by taking advantage of ALiBi’s positional attention extrapolation which has been shown to allow sequence lengths of 2x the initial trained length.
1864
+
1865
+ ## Full Model Architecture
1866
+ ```
1867
+ SentenceTransformer(
1868
+ (0): Transformer({'max_seq_length': 4096, 'do_lower_case': False}) with Transformer model: BertModel
1869
+ (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False})
1870
+ (2): Normalize()
1871
+ )
1872
+ ```
1873
+
1874
+ #### Training data
1875
+
1876
+ We use the concatenation from multiple datasets to fine-tune our model. The total number of sentence pairs is nearly 1.5 billion sentences. We sampled each dataset given a weighted probability proportional to its relative contribution to the entire dataset.
1877
+ The breakdown of the dataset can be seen below, and the entire dataset can be publicly accessed and uploaded via the `Dataloading.ipynb` located within this repository.
1878
+
1879
+ | Dataset | Paper | Number of training tuples |
1880
+ |--------------------------------------------------------|:----------------------------------------:|:--------------------------:|
1881
+ | [Reddit comments (2015-2018)](https://github.com/PolyAI-LDN/conversational-datasets/tree/master/reddit) | [paper](https://arxiv.org/abs/1904.06472) | 726,484,430 |
1882
+ | **[S2ORC](https://github.com/allenai/s2orc) Citation pairs (Abstracts)** | [paper](https://aclanthology.org/2020.acl-main.447/) | 252,102,397 |
1883
+ | **[Reddit posts](https://huggingface.co/datasets/sentence-transformers/reddit-title-body) (Title, Body) pairs** | - | 127,445,911 |
1884
+ | **[Amazon reviews (2018)](https://huggingface.co/datasets/sentence-transformers/embedding-training-data) (Title, Review) pairs** | - | 87,877,725 |
1885
+ | [WikiAnswers](https://github.com/afader/oqa#wikianswers-corpus) Duplicate question pairs | [paper](https://doi.org/10.1145/2623330.2623677) | 77,427,422 |
1886
+ | [PAQ](https://github.com/facebookresearch/PAQ) (Question, Answer) pairs | [paper](https://arxiv.org/abs/2102.07033) | 64,371,441 |
1887
+ | [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Titles) | [paper](https://aclanthology.org/2020.acl-main.447/) | 52,603,982 |
1888
+ | [S2ORC](https://github.com/allenai/s2orc) (Title, Abstract) | [paper](https://aclanthology.org/2020.acl-main.447/) | 41,769,185 |
1889
+ | [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_title_body_jsonl) (Title, Body) pairs | - | 25,368,423 |
1890
+ | [MS MARCO](https://microsoft.github.io/msmarco/) triplets | [paper](https://doi.org/10.1145/3404835.3462804) | 9,144,553 |
1891
+ | **[Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_title_best_voted_answer_jsonl) (Title, Most Upvoted Answer) pairs** | - | 4,784,250 |
1892
+ | **[Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_titlebody_best_voted_answer_jsonl) (Title+Body, Most Upvoted Answer) pairs** | - | 4,551,660 |
1893
+ | [GOOAQ: Open Question Answering with Diverse Answer Types](https://github.com/allenai/gooaq) | [paper](https://arxiv.org/pdf/2104.08727.pdf) | 3,012,496 |
1894
+ | **[Amazon QA](https://huggingface.co/datasets/sentence-transformers/embedding-training-data)** | - | 2,507,114 |
1895
+ | [Code Search](https://huggingface.co/datasets/code_search_net) | - | 1,375,067 |
1896
+ | [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 1,198,260 |
1897
+ | **[AG News]((Title, Description) pairs of news articles from the AG News dataset)** | - | 1,157,745 |
1898
+ | [COCO](https://cocodataset.org/#home) Image captions | [paper](https://link.springer.com/chapter/10.1007%2F978-3-319-10602-1_48) | 828,395|
1899
+ | [SPECTER](https://github.com/allenai/specter) citation triplets | [paper](https://doi.org/10.18653/v1/2020.acl-main.207) | 684,100 |
1900
+ | [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Question, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 681,164 |
1901
+ | [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Question) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 659,896 |
1902
+ | **[CC News](https://huggingface.co/datasets/sentence-transformers/embedding-training-data) (Title, article) pairs** | - | 614,664 |
1903
+ | **[NPR](https://huggingface.co/datasets/sentence-transformers/embedding-training-data) (Title, Body) pairs** | - | 594,384 |
1904
+ | [SearchQA](https://huggingface.co/datasets/search_qa) | [paper](https://arxiv.org/abs/1704.05179) | 582,261 |
1905
+ | **[MS Marco](https://microsoft.github.io/msmarco/) (Query, Answer Passage) pairs** | [paper](https://doi.org/10.1145/3404835.3462804) | 532,751 |
1906
+ | [Stack Exchange](https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0) (Title, Body) pairs | - | 364,000 |
1907
+ | [Eli5](https://huggingface.co/datasets/eli5) | [paper](https://doi.org/10.18653/v1/p19-1346) | 325,475 |
1908
+ | [Flickr 30k](https://shannon.cs.illinois.edu/DenotationGraph/) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/229/33) | 317,695 |
1909
+ | **[CNN & DailyMail](https://huggingface.co/datasets/sentence-transformers/embedding-training-data) (highlight sentences, article) pairs** | - | 311,971 |
1910
+ | [Stack Exchange](https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0) Duplicate questions (titles) | - | 304,524 |
1911
+ | AllNLI ([SNLI](https://nlp.stanford.edu/projects/snli/) and [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) | [paper SNLI](https://doi.org/10.18653/v1/d15-1075), [paper MultiNLI](https://doi.org/10.18653/v1/n18-1101) | 277,230 |
1912
+ | [Stack Exchange](https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0) Duplicate questions (bodies) | - | 250,518 |
1913
+ | [Stack Exchange](https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0) Duplicate questions (titles+bodies) | - | 250,459 |
1914
+ | **[XSUM](https://huggingface.co/datasets/sentence-transformers/embedding-training-data) (Summary, News Article) pairs** | - | 226,711 |
1915
+ | **[Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_titlebody_best_and_down_voted_answer_jsonl) (Title+Body, Most Upvoted Answer, Most Downvoted Answer) triplets** | - | 216,454 |
1916
+ | [Sentence Compression](https://github.com/google-research-datasets/sentence-compression) | [paper](https://www.aclweb.org/anthology/D13-1155/) | 180,000 |
1917
+ | **[FEVER](https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0) training data** | - | 139,051 |
1918
+ | [Wikihow](https://github.com/pvl/wikihow_pairs_dataset) | [paper](https://arxiv.org/abs/1810.09305) | 128,542 |
1919
+ | **[SearchQA](https://huggingface.co/datasets/search_qa) (Question, Top-Snippet)** | [paper](https://arxiv.org/abs/1704.05179) | 117,384 |
1920
+ | [Altlex](https://github.com/chridey/altlex/) | [paper](https://aclanthology.org/P16-1135.pdf) | 112,696 |
1921
+ | **[Quora Question Duplicates](https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs)** | - | 103,663 |
1922
+ | [Quora Question Triplets](https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs) | - | 103,663 |
1923
+ | [Simple Wikipedia](https://cs.pomona.edu/~dkauchak/simplification/) | [paper](https://www.aclweb.org/anthology/P11-2117/) | 102,225 |
1924
+ | [Natural Questions (NQ)](https://ai.google.com/research/NaturalQuestions) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/1455) | 100,231 |
1925
+ | [SQuAD2.0](https://rajpurkar.github.io/SQuAD-explorer/) | [paper](https://aclanthology.org/P18-2124.pdf) | 87,599 |
1926
+ | [TriviaQA](https://huggingface.co/datasets/trivia_qa) | - | 73,346 |
1927
+ | **Total** | | **1,492,453,113** |
1928
+
1929
+ #### Replication
1930
+
1931
+ The entire fine-tuning process for this model can be replicated by following the steps outlined in the `Replication.txt` file within this repository. This document explains how to modify the [sentence-transformers](https://www.SBERT.net) library, configure the pre-trained [`mosaic-bert-base-seqlen-2048`](https://huggingface.co/mosaicml/mosaic-bert-base-seqlen-2048) model, load all of the training data, and execute the training script.
1932
+
1933
+ #### Limitations
1934
+
1935
+ Due to technical constraints (e.g. limited GPU memory capacity), this model was trained with a smaller batch size of 16, making it so that each step during training was less well-informed than it would have been on a higher performance system. This smaller than ideal hyperparameter value will generally cause the model to be more likely to get stuck in a local minimum and for the parameter configuration to take a longer time to converge to the optimum. In order to counteract this potential risk, we trained the model for a larger number of steps than many of its contemporaries to ensure a greater chance of achieving strong performance, but this is an area which could be improved if further fine-tuning was performed.
1936
+
1937
+ It is also worth noting that, while this model is able to handle longer input sequences of up to 4096 word pieces, the training dataset used consists of sentence and paragraph pairs and triplets which do not necessarily reach that maximum sequence length. Since the data was not tailored specifically for this larger input size, further fine-tuning may be required to ensure highly accurate embeddings for longer texts of that magnitude.
1938
+
1939
+ Finally, as stated on https://huggingface.co/datasets/sentence-transformers/reddit-title-body, an additional reminder and warning regarding the Reddit posts data is that one should "Be aware that this dataset is not filtered for biases, hate-speech, spam, racial slurs etc. It depicts the content as it is posted on Reddit." Thus, while we believe this has not induced any pathological behaviors in the model's performance due to its relatively low prevalence of records in the whole dataset of nearly 1.5B sentence pairs and the fact that this model was trained to produce semantic embeddings rather than generative text outputs, it is always important to be aware of vulnerabilities to bias.
1940
+
Replication.txt ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Lodestone Replication
2
+
3
+ The dataloading, library modification, model preparation, and training process can be replicated in a straightforward manner by simply running a few Jupyter notebooks and Python files.
4
+
5
+ Data Wrangling and Loading
6
+
7
+ Dataloading.ipynb utilizes the contents of the GoogleSheets_datasets.tsv and HuggingFace_datasets.tsv to fetch data from various URLs provided by the original distilroberta team to their curated datasets in cloud storage. The data is then streamed directly into the data folder of the lodestone-rnd S3 bucket in us-east-1. In addition to the data used by the distilroberta team and provided at https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0, data was also collected from https://huggingface.co/datasets/sentence-transformers/embedding-training-data and the following HuggingFace dataset repositories:
8
+ Stack Exchange
9
+ https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_title_body_jsonl
10
+ https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_titlebody_best_voted_answer_jsonl
11
+ https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_title_best_voted_answer_jsonl
12
+ https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_titlebody_best_and_down_voted_answer_jsonl
13
+ Reddit
14
+ https://huggingface.co/datasets/sentence-transformers/reddit-title-body
15
+ All of the HuggingFace data is handled remotely or pulled via the script in Dataloading.ipynb, so the only files required for this entire process are Dataloading.ipynb, GoogleSheets_datasets.tsv, and HuggingFace_datasets.tsv. Running this notebook results in 679 objects and 310.5GB of data being loaded into S3.
16
+
17
+ Once the data is in S3, run Data_Records.ipynb to generate the data_records.json file which contains a dictionary of {filename: record count} pairs and is used throughout the Training.py script.
18
+
19
+ Library and Model Preparation
20
+
21
+ In order to run the training process with our specific model, we need to make a few custom modifications to the sentence-transformers library and to the config.json file of the mosaic-bert-base-seqlen-2048 base model.
22
+
23
+ To alter the sentence-transformers library, clone the repository from https://github.com/UKPLab/sentence-transformers locally and replace the SentenceTransformer.py and Transformer.py files located within the sentence-transformers/sentence_transformers/ and sentence-transformers/sentence_transformers/models/ directories of the cloned repository, respectively, with those located inside dev/ folder. (This has already been done in this notebook instance, but this will have to be completed if training on another system.)
24
+
25
+ Before conducting actual training, we also need to clone the mosaic-bert-base-seqlen-2048 model locally and make a few small changes to its config.json file. Running Mosaic_Model.ipynb will execute this process and get our model ready to begin training. (Again, this has already been done in this notebook instance, but this will have to be completed if training on another system.)
26
+
27
+ Training
28
+
29
+ To perform the final training run, open a SageMaker Terminal window and execute the following:
30
+ cd SageMaker
31
+ screen -S training
32
+ python Training.py
33
+ ^a d (that is, Ctrl + a, then d)
34
+
35
+ To reattach to the screen and observe how training is progressing, run `screen -r training` in the Terminal. Occasionally epochs may stall and require manual intervention to kickstart the process again. Pressing ^c (that is, Ctrl+c) inside the screen should suffice the get things going again, but this action will automatically cause the currently stalled epoch to fail and for the training to proceed to the next epoch or data chunk without updating the existing model parameterization. Epoch successes and failures and the cumulative number of successfully completed steps can be monitored via the train_logs.txt file which is updated automatically throughout the course of training.
36
+
37
+ The Training.py file can be reconfigured such that training hyperparameters could be passed in through the command line, but, at present, hyperparameters should be set within the file before running it.
38
+
39
+ This concludes the steps required for replication of the Lodestone training process.
40
+
Training.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This training script is a duplicate of the Training.ipynb notebook but can be invoked from the terminal
2
+
3
+ import os
4
+ print(os.getcwd())
5
+ os.environ["PATH"]="/usr/local/cuda-11.7/bin:"+os.getenv("PATH")
6
+
7
+ os.system('pip uninstall -y torch')
8
+ os.system('pip uninstall -y einops')
9
+ os.system('pip uninstall -y transformers')
10
+ os.system('pip uninstall -y sentence_transformers')
11
+ os.system('pip uninstall -y datasets')
12
+ os.system('pip uninstall -y sagemaker')
13
+ os.system('pip uninstall -y smart_open')
14
+ os.system('pip uninstall -y pynvml')
15
+
16
+ os.system('pip install -r lodestone-reqs.txt')
17
+
18
+ os.system('pip install -e ./sentence-transformers')
19
+
20
+ os.system('pip uninstall -y triton')
21
+ os.system('pip install --no-deps triton==2.0.0.dev20221202')
22
+
23
+ #####
24
+
25
+ from pynvml import *
26
+ import math
27
+ from sentence_transformers import models, losses
28
+ from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
29
+ import logging
30
+ import os
31
+ import json
32
+ import torch
33
+ import boto3
34
+ from smart_open import open
35
+ import random
36
+ import time
37
+ import gc
38
+
39
+ os.environ["PATH"]="/usr/local/cuda-11.7/bin:"+os.getenv("PATH")
40
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
41
+
42
+ #####
43
+
44
+
45
+ def print_gpu_utilization():
46
+ "This helper function outputs the current GPU memory usage."
47
+ nvmlInit()
48
+ handle = nvmlDeviceGetHandleByIndex(0)
49
+ info = nvmlDeviceGetMemoryInfo(handle)
50
+ return f"GPU memory occupied: {info.used/1024**3} GB."
51
+
52
+ #####
53
+
54
+
55
+ class MultiDatasetDataLoader:
56
+ """
57
+ This custom dataloader class consumes a list of datasets and a batch size and produces batches randomly sampled
58
+ from the datasets provided where each batch consists of records from a single dataset and datasets are chosen
59
+ for batches in proportion to their total number of records.
60
+ """
61
+ def __init__(self, datasets, batch_size_pairs, batch_size_triplets=None, dataset_size_temp=-1, allow_swap=True):
62
+ self.allow_swap = allow_swap
63
+ self.batch_size_pairs = batch_size_pairs
64
+ self.batch_size_triplets = batch_size_pairs if batch_size_triplets is None else batch_size_triplets
65
+
66
+ # Compute dataset weights
67
+ self.dataset_lengths = list(map(len, datasets))
68
+ self.dataset_lengths_sum = sum(self.dataset_lengths)
69
+
70
+ weights = []
71
+ # if dataset_size_temp > 0: # Scale probability with dataset size
72
+ # for dataset in datasets:
73
+ # prob = len(dataset) / self.dataset_lengths_sum
74
+ # weights.append(max(1, int(math.pow(prob, 1 / dataset_size_temp) * 1000)))
75
+ # else: # Equal weighting of all datasets
76
+ # weights = [100] * len(datasets)
77
+ for dataset in datasets:
78
+ weights.append(len(dataset))
79
+
80
+ # logging.info("Dataset lengths and weights: {}".format(list(zip(self.dataset_lengths, weights))))
81
+
82
+ self.dataset_idx = []
83
+ self.dataset_idx_pointer = 0
84
+
85
+ for idx, weight in enumerate(weights):
86
+ self.dataset_idx.extend([idx] * weight)
87
+ random.shuffle(self.dataset_idx)
88
+
89
+ self.datasets = []
90
+ for dataset in datasets:
91
+ random.shuffle(dataset)
92
+ self.datasets.append({
93
+ 'elements': dataset,
94
+ 'pointer': 0,
95
+ })
96
+
97
+ def __iter__(self):
98
+ for _ in range(int(self.__len__())):
99
+ # Select dataset
100
+ if self.dataset_idx_pointer >= len(self.dataset_idx):
101
+ self.dataset_idx_pointer = 0
102
+ random.shuffle(self.dataset_idx)
103
+
104
+ dataset_idx = self.dataset_idx[self.dataset_idx_pointer]
105
+ self.dataset_idx_pointer += 1
106
+
107
+ # Select batch from this dataset
108
+ dataset = self.datasets[dataset_idx]
109
+ batch_size = self.batch_size_pairs if len(dataset['elements'][0].texts) == 2 else self.batch_size_triplets
110
+
111
+ batch = []
112
+ texts_in_batch = set()
113
+ guid_in_batch = set()
114
+ while len(batch) < batch_size:
115
+ example = dataset['elements'][dataset['pointer']]
116
+
117
+ valid_example = True
118
+ # First check if one of the texts in already in the batch
119
+ for text in example.texts:
120
+ text_norm = text.strip().lower()
121
+ if text_norm in texts_in_batch:
122
+ valid_example = False
123
+
124
+ texts_in_batch.add(text_norm)
125
+
126
+ # If the example has a label, check if label is in batch
127
+ if example.guid is not None:
128
+ valid_example = valid_example and example.guid not in guid_in_batch
129
+ guid_in_batch.add(example.guid)
130
+
131
+ if valid_example:
132
+ if self.allow_swap and random.random() > 0.5:
133
+ example.texts[0], example.texts[1] = example.texts[1], example.texts[0]
134
+
135
+ batch.append(example)
136
+
137
+ dataset['pointer'] += 1
138
+ if dataset['pointer'] >= len(dataset['elements']):
139
+ dataset['pointer'] = 0
140
+ random.shuffle(dataset['elements'])
141
+
142
+ yield self.collate_fn(batch) if self.collate_fn is not None else batch
143
+
144
+ def __len__(self):
145
+ return int(self.dataset_lengths_sum / self.batch_size_pairs)
146
+
147
+ #####
148
+
149
+
150
+ # These four classes of custom generators parse the raw data from the files in S3 and format it into InputExamples which can be properly interpreted by a SentenceTransformer model.
151
+
152
+ class RedditTitleBodyDataset:
153
+ def __init__(self, source_uri, max_seq_length):
154
+ self.source_uri = source_uri
155
+ self.s3_client = boto3.client("s3")
156
+ self.max_seq_length = max_seq_length
157
+
158
+ def __iter__(self):
159
+ while True:
160
+ for json_line in open(self.source_uri, transport_params={"client": self.s3_client}):
161
+ data_line = json.loads(json_line.strip())
162
+
163
+ if "title" in data_line and "body" in data_line:
164
+ data = {'guid': None, 'texts': [" ".join(data_line['title'].split(" ")[:self.max_seq_length]), " ".join(data_line['body'].split(" ")[:self.max_seq_length])]}
165
+ record = InputExample(guid=data.get('guid', None), texts=data['texts'])
166
+
167
+ yield record
168
+
169
+
170
+ class RedditYearDataset:
171
+ def __init__(self, source_uri, max_seq_length):
172
+ self.source_uri = source_uri
173
+ self.s3_client = boto3.client("s3")
174
+ self.max_seq_length = max_seq_length
175
+
176
+ def __iter__(self):
177
+ while True:
178
+ for json_line in open(self.source_uri, transport_params={"client": self.s3_client}):
179
+ data_line = json.loads(json_line.strip())
180
+
181
+ if "response" in data_line and "context" in data_line:
182
+ data = {'guid': None, 'texts': [" ".join(data_line['response'].split(" ")[:self.max_seq_length]), " ".join(data_line['context'].split(" ")[:self.max_seq_length])]}
183
+ record = InputExample(guid=data.get('guid', None), texts=data['texts'])
184
+
185
+ yield record
186
+
187
+
188
+ class HuggingFaceQueryPosDataset:
189
+ def __init__(self, source_uri, max_seq_length):
190
+ self.source_uri = source_uri
191
+ self.s3_client = boto3.client("s3")
192
+ self.max_seq_length = max_seq_length
193
+
194
+ def __iter__(self):
195
+ while True:
196
+ for json_line in open(self.source_uri, transport_params={"client": self.s3_client}):
197
+ data_line = json.loads(json_line.strip())
198
+
199
+ if "query" in data_line and "pos" in data_line:
200
+ for i in range(len(data_line['pos'])):
201
+ data = {'guid': None, 'texts': [" ".join(data_line['query'].split(" ")[:self.max_seq_length]), " ".join(data_line['pos'][i].split(" ")[:self.max_seq_length])]}
202
+ record = InputExample(guid=data.get('guid', None), texts=data['texts'])
203
+
204
+ yield record
205
+
206
+
207
+ class Dataset:
208
+ def __init__(self, source_uri, max_seq_length):
209
+ self.source_uri = source_uri
210
+ self.s3_client = boto3.client("s3")
211
+ self.max_seq_length = max_seq_length
212
+
213
+ def __iter__(self):
214
+ while True:
215
+ for json_line in open(self.source_uri, transport_params={"client": self.s3_client}):
216
+ data_line = json.loads(json_line.strip())
217
+
218
+ if not isinstance(data_line, dict):
219
+ data = {'guid': None, 'texts': data_line}
220
+ for text_idx in range(len(data['texts'])):
221
+ data['texts'][text_idx] = " ".join(data['texts'][text_idx].split(" ")[:self.max_seq_length])
222
+ record = InputExample(guid=data.get('guid', None), texts=data['texts'])
223
+ else:
224
+ for text_idx in range(len(data_line['texts'])):
225
+ data_line['texts'][text_idx] = " ".join(data_line['texts'][text_idx].split(" ")[:self.max_seq_length])
226
+ record = InputExample(guid=data_line.get('guid', None), texts=data_line['texts'])
227
+
228
+ yield record
229
+
230
+ #####
231
+
232
+
233
+ def build_generators(data_records, max_seq_length=512, testing=False):
234
+ """
235
+ This function consumes the data_records dictionary and creates a new dictionary of data generators where each entry is
236
+ of the form {filename: data generator object}.
237
+ """
238
+ if testing:
239
+ # filepaths = [file for file in list(data_records.keys()) if file.startswith('S2ORC') or file.startswith('reddit_')]
240
+ filepaths = [file for file in list(data_records.keys())][:3]
241
+ else:
242
+ filepaths = list(data_records.keys())
243
+ generators = {}
244
+ for filepath in filepaths:
245
+ filepath = filepath.strip()
246
+ source_uri = 's3://lodestone-rnd/data/'+filepath
247
+ if filepath in ['S2ORC_citations_abstracts.json.gz', 'amazon-qa.json.gz'] or 'reddit' in filepath:
248
+ if "title" in filepath:
249
+ generators[f'{filepath.split(".")[0]}'] = iter(RedditTitleBodyDataset(source_uri, max_seq_length))
250
+ elif "reddit" in filepath:
251
+ generators[f'{filepath.split(".")[0]}'] = iter(RedditYearDataset(source_uri, max_seq_length))
252
+ else:
253
+ generators[f'{filepath.split(".")[0]}'] = iter(HuggingFaceQueryPosDataset(source_uri, max_seq_length))
254
+ else:
255
+ generators[f'{filepath.split(".")[0]}'] = iter(Dataset(source_uri, max_seq_length))
256
+
257
+ return generators
258
+
259
+ #####
260
+
261
+
262
+ def produce_data(data_records, num_chunks, generators, batch_size, failed_on=None, first_iter=False, testing=False, temp=-1):
263
+ """
264
+ This function consumes the data_records dictionary, the number of chunks to break the datasets into, the dictionary of
265
+ data generators, and a batch size and returns a MultiDatasetDataloader which can be fed into the .fit method of a
266
+ SentenceTransformer model.
267
+ """
268
+ if testing:
269
+ # filepaths = [file for file in list(data_records.keys()) if file.startswith('S2ORC') or file.startswith('reddit_')]
270
+ filepaths = [file for file in list(data_records.keys())][:3]
271
+ else:
272
+ filepaths = list(data_records.keys())
273
+ datasets = []
274
+ for file_idx, filepath in enumerate(filepaths):
275
+ filepath = filepath.strip()
276
+ dataset = []
277
+
278
+ if failed_on is not None and failed_on != 1 and first_iter:
279
+ for k in range((failed_on-1)*max(1, data_records[filepath]//num_chunks)):
280
+ next(generators[f'{filepath.split(".")[0]}'])
281
+ for m in range(max(1, data_records[filepath]//num_chunks)):
282
+ dataset.append(next(generators[f'{filepath.split(".")[0]}']))
283
+ else:
284
+ for n in range(max(1, data_records[filepath]//num_chunks)):
285
+ dataset.append(next(generators[f'{filepath.split(".")[0]}']))
286
+
287
+ datasets.append(dataset)
288
+ logging.info("{}. {}: {}".format(file_idx+1, filepath, len(dataset)))
289
+
290
+ dataset_lengths_sum = sum(list(map(len, datasets)))
291
+
292
+ batch_size_pairs = batch_size_triplets = batch_size
293
+ # Special data loader to load from multiple datasets
294
+ train_dataloader = MultiDatasetDataLoader(datasets=datasets,
295
+ batch_size_pairs=batch_size_pairs,
296
+ batch_size_triplets=batch_size_triplets,
297
+ dataset_size_temp=temp)
298
+
299
+ return train_dataloader, dataset_lengths_sum
300
+
301
+ #####
302
+
303
+
304
+ def construct_model(model_name, max_seq_length=512):
305
+ """
306
+ This function constructs a SentenceTransformer model from a HuggingFace transformer model name
307
+ or from a local path to a transformer model repository.
308
+ """
309
+ word_embedding_model = models.Transformer(model_name_or_path=model_name,
310
+ max_seq_length=max_seq_length,
311
+ tokenizer_name_or_path='bert-base-uncased',
312
+ trust_remote_code=True,
313
+ model_args={'torch_dtype': torch.bfloat16})
314
+ pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
315
+ norm = models.Normalize()
316
+ model = SentenceTransformer(modules=[word_embedding_model, pooling_model, norm], device='cuda')
317
+ model[0].tokenizer.model_max_length = max_seq_length
318
+
319
+ return model
320
+
321
+ #####
322
+
323
+
324
+ # Just some code to print debug information to stdout
325
+ logging.basicConfig(format='%(asctime)s - %(message)s',
326
+ datefmt='%Y-%m-%d %H:%M:%S',
327
+ level=logging.INFO,
328
+ handlers=[LoggingHandler()])
329
+ # /print debug information to stdout
330
+
331
+ #####
332
+
333
+
334
+ # Set Hyperparameters
335
+ model_name = 'mosaic-bert-base-seqlen-2048'
336
+ # model_name = 'hum-lodestone-v1'
337
+ batch_size = 16
338
+ batch_size_pairs = batch_size_triplets = batch_size
339
+ max_seq_length = 2048
340
+ use_amp = False
341
+
342
+ num_cycles = 2
343
+ num_chunks = 50
344
+ num_epochs = 2
345
+ steps_per_epoch = 10000
346
+ # Total training steps = num_cycles * num_chunks * num_epochs * steps_per_epoch = 2 * 50 * 2 * 10,000 = 2,000,000 steps
347
+ warmup_steps = 500
348
+
349
+ testing = False
350
+ temp = -1
351
+
352
+ #####
353
+
354
+
355
+ output_path = 'hum-lodestone-v1'
356
+ logging.info("Output: "+output_path)
357
+
358
+ # Instantiate SentenceTransformer Model
359
+ model = construct_model(model_name=model_name, max_seq_length=max_seq_length)
360
+
361
+ # Load File Names and Record Volumes
362
+ with open('data_records.json') as fIn:
363
+ data_records = json.load(fIn)
364
+
365
+ total_pairs = sum(data_records.values())
366
+
367
+ logging.info("Total Training Pairs: {}".format(total_pairs))
368
+
369
+ # Initialize Data Generators
370
+ generators = build_generators(data_records=data_records,
371
+ max_seq_length=max_seq_length,
372
+ testing=testing)
373
+
374
+ logging.info("Data Generators Initialized")
375
+
376
+ # Define Training Loss Function
377
+ train_loss = losses.MultipleNegativesRankingLoss(model,
378
+ scale=20,
379
+ similarity_fct=util.dot_score)
380
+
381
+ logging.info(print_gpu_utilization())
382
+
383
+ #####
384
+
385
+
386
+ # Configure Training Cycles
387
+ failed_on = None # chunk that the process failed on
388
+ random.seed(42)
389
+ steps = 0
390
+ first_iter = True
391
+ for cycle_num in range(num_cycles):
392
+ logging.info("Starting Cycle {}".format(cycle_num+1))
393
+ for chunk_num in range(num_chunks):
394
+ if failed_on is not None and (chunk_num+1) < failed_on and (cycle_num+1) == 1:
395
+ pass
396
+ else:
397
+ logging.info("Chunk {}/{}".format(chunk_num+1, num_chunks))
398
+ logging.info("Loading {} Datasets".format(len([file for file in list(data_records.keys()) if file.startswith('S2ORC') or file.startswith('reddit_')]) if testing else len(data_records)))
399
+ # t_dataload0 = time.time()
400
+ # Create the training dataloader for the given chunk of data
401
+ train_dataloader, dataset_lengths_sum = produce_data(data_records,
402
+ num_chunks,
403
+ generators,
404
+ batch_size,
405
+ failed_on=failed_on,
406
+ first_iter=first_iter,
407
+ testing=testing,
408
+ temp=temp)
409
+ first_iter = False
410
+ # t_dataload1 = time.time()
411
+ # print(t_dataload1-t_dataload0)
412
+
413
+ logging.info(print_gpu_utilization())
414
+
415
+ # steps_per_epoch = dataset_lengths_sum // batch_size_pairs
416
+
417
+ for epoch_num in range(num_epochs):
418
+ logging.info("Performing Cycle {}, Chunk {}, Epoch {}".format(cycle_num+1, chunk_num+1, epoch_num+1))
419
+ try:
420
+ # t_fit0 = time.time()
421
+ # Train the model
422
+ model.fit(train_objectives=[(train_dataloader, train_loss)],
423
+ evaluator=None,
424
+ epochs=1,
425
+ warmup_steps=warmup_steps,
426
+ steps_per_epoch=steps_per_epoch,
427
+ use_amp=use_amp,
428
+ output_path=output_path)
429
+ # t_fit1 = time.time()
430
+ # print(t_fit1-t_fit0)
431
+
432
+ steps += steps_per_epoch
433
+
434
+ logging.info(print_gpu_utilization())
435
+ logging.info("Succeeded on Cycle {}, Chunk {}, Epoch {}".format(cycle_num+1, chunk_num+1, epoch_num+1))
436
+ logging.info("{} Steps Completed in Total".format(steps))
437
+
438
+ with open('train_logs.txt', 'a') as log:
439
+ log.write("Succeeded on Cycle {}, Chunk {}, Epoch {}: {} Steps Completed in Total\n".format(cycle_num+1, chunk_num+1, epoch_num+1, steps))
440
+
441
+ except:
442
+ logging.info("Failed on Cycle {}, Chunk {}, Epoch {}".format(cycle_num+1, chunk_num+1, epoch_num+1))
443
+
444
+ with open('train_logs.txt', 'a') as log:
445
+ log.write("Failed on Cycle {}, Chunk {}, Epoch {}: {} Steps Completed in Total\n".format(cycle_num+1, chunk_num+1, epoch_num+1, steps))
446
+
447
+ finally:
448
+ warmup_steps = 0
449
+
450
+ # Clear GPU/CUDA memory cache between data chunks
451
+ train_dataloader = None
452
+ model = None
453
+ train_loss = None
454
+
455
+ gc.collect()
456
+ torch.cuda.empty_cache()
457
+
458
+ # Reload the model and reinitialize the loss function
459
+ model = construct_model(model_name='hum-lodestone-v1', max_seq_length=max_seq_length)
460
+
461
+ train_loss = losses.MultipleNegativesRankingLoss(model,
462
+ scale=20,
463
+ similarity_fct=util.dot_score)
464
+
465
+ logging.info(print_gpu_utilization())
bert_layers.py ADDED
@@ -0,0 +1,1072 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
5
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
6
+ # Copyright (c) 2022, Tri Dao.
7
+
8
+ """Implements Mosaic BERT, with an eye towards the Hugging Face API.
9
+
10
+ Mosaic BERT improves performance over Hugging Face BERT through the following:
11
+
12
+ 1. ALiBi. This architectural change removes positional embeddings and instead encodes positional
13
+ information through attention biases based on query-key position distance. It improves the effectiveness
14
+ of training with shorter sequence lengths by enabling extrapolation to longer sequences.
15
+
16
+ 2. Gated Linear Units (GLU). This architectural change replaces the FFN component of the BERT layer
17
+ to improve overall expressiveness, providing better convergence properties.
18
+
19
+ 3. Flash Attention. The Mosaic BERT's self-attention layer makes use of Flash Attention, which dramatically
20
+ improves the speed of self-attention. Our implementation utilizes a bleeding edge implementation that
21
+ supports attention biases, which allows us to use Flash Attention with ALiBi.
22
+
23
+ 4. Unpadding. Padding is often used to simplify batching across sequences of different lengths. Standard BERT
24
+ implementations waste computation on padded tokens. Mosaic BERT internally unpads to reduce unnecessary computation
25
+ and improve speed. It does this without changing how the user interfaces with the model, thereby
26
+ preserving the simple API of standard implementations.
27
+
28
+
29
+ Currently, Mosaic BERT is available for masked language modeling :class:`BertForMaskedLM` and sequence
30
+ classification :class:`BertForSequenceClassification`. We aim to expand this catalogue in future releases.
31
+
32
+ See :file:`./mosaic_bert.py` for utilities to simplify working with Mosaic BERT in Composer, and for example usage
33
+ of the core Mosaic BERT classes.
34
+ """
35
+
36
+ import copy
37
+ import logging
38
+ import math
39
+ import warnings
40
+ from typing import List, Optional, Tuple, Union
41
+
42
+ import torch
43
+ import torch.nn as nn
44
+ from einops import rearrange
45
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
46
+ from transformers.activations import ACT2FN
47
+ from transformers.modeling_outputs import (MaskedLMOutput,
48
+ SequenceClassifierOutput)
49
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel
50
+
51
+ from .bert_padding import (index_first_axis,
52
+ index_put_first_axis, pad_input,
53
+ unpad_input, unpad_input_only)
54
+
55
+ try:
56
+ from .flash_attn_triton import flash_attn_qkvpacked_func
57
+ except ImportError as e:
58
+ flash_attn_qkvpacked_func = None
59
+
60
+ logger = logging.getLogger(__name__)
61
+
62
+
63
+ class BertEmbeddings(nn.Module):
64
+ """Construct the embeddings for words, ignoring position.
65
+
66
+ There are no positional embeddings since we use ALiBi and token_type
67
+ embeddings.
68
+
69
+ This module is modeled after the Hugging Face BERT's
70
+ :class:`~transformers.model.bert.modeling_bert.BertEmbeddings`, but is
71
+ modified as part of Mosaic BERT's ALiBi implementation. The key change is
72
+ that position embeddings are removed. Position information instead comes
73
+ from attention biases that scale linearly with the position distance
74
+ between query and key tokens.
75
+
76
+ This module ignores the `position_ids` input to the `forward` method.
77
+ """
78
+
79
+ def __init__(self, config):
80
+ super().__init__()
81
+ self.word_embeddings = nn.Embedding(config.vocab_size,
82
+ config.hidden_size,
83
+ padding_idx=config.pad_token_id)
84
+ # ALiBi doesn't use position embeddings
85
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
86
+ config.hidden_size)
87
+
88
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model
89
+ # variable name and be able to load any TensorFlow checkpoint file
90
+ self.LayerNorm = nn.LayerNorm(config.hidden_size,
91
+ eps=config.layer_norm_eps)
92
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
93
+ self.register_buffer('token_type_ids',
94
+ torch.zeros(config.max_position_embeddings,
95
+ dtype=torch.long),
96
+ persistent=False)
97
+
98
+ def forward(
99
+ self,
100
+ input_ids: Optional[torch.LongTensor] = None,
101
+ token_type_ids: Optional[torch.LongTensor] = None,
102
+ position_ids: Optional[torch.LongTensor] = None,
103
+ inputs_embeds: Optional[torch.FloatTensor] = None,
104
+ past_key_values_length: int = 0,
105
+ ) -> torch.Tensor:
106
+ if (input_ids is not None) == (inputs_embeds is not None):
107
+ raise ValueError('Must specify either input_ids or input_embeds!')
108
+ if input_ids is not None:
109
+ input_shape = input_ids.size()
110
+ else:
111
+ assert inputs_embeds is not None # just for type checking
112
+ input_shape = inputs_embeds.size()[:-1]
113
+
114
+ seq_length = input_shape[1]
115
+
116
+ if position_ids is None:
117
+ # great! ALiBi
118
+ pass
119
+
120
+ # Setting the token_type_ids to the registered buffer in constructor
121
+ # where it is all zeros, which usually occurs when it's auto-generated;
122
+ # registered buffer helps users when tracing the model without passing
123
+ # token_type_ids, solves issue #5664
124
+ if token_type_ids is None:
125
+ if hasattr(self, 'token_type_ids'):
126
+ assert isinstance(self.token_type_ids, torch.LongTensor)
127
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
128
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
129
+ input_shape[0], seq_length)
130
+ token_type_ids = buffered_token_type_ids_expanded # type: ignore
131
+ else:
132
+ token_type_ids = torch.zeros(input_shape, # type: ignore
133
+ dtype=torch.long,
134
+ device=self.word_embeddings.device) # type: ignore # yapf: disable
135
+
136
+ if inputs_embeds is None:
137
+ inputs_embeds = self.word_embeddings(input_ids)
138
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
139
+
140
+ embeddings = inputs_embeds + token_type_embeddings
141
+ # no position embeddings! ALiBi
142
+ embeddings = self.LayerNorm(embeddings)
143
+ embeddings = self.dropout(embeddings)
144
+ return embeddings
145
+
146
+
147
+ class BertUnpadSelfAttention(nn.Module):
148
+ """Performs multi-headed self attention on a batch of unpadded sequences.
149
+
150
+ If Triton is installed, this module uses Flash Attention to greatly improve throughput.
151
+ The Flash Attention implementation used in Mosaic BERT supports arbitrary attention biases (which
152
+ we use to implement ALiBi), but does not support attention dropout. If either Triton is not installed
153
+ or `config.attention_probs_dropout_prob > 0`, the implementation will default to a
154
+ math-equivalent pytorch version, which is much slower.
155
+
156
+ See `forward` method for additional detail.
157
+ """
158
+
159
+ def __init__(self, config):
160
+ super().__init__()
161
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
162
+ config, 'embedding_size'):
163
+ raise ValueError(
164
+ f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention '
165
+ f'heads ({config.num_attention_heads})')
166
+
167
+ self.num_attention_heads = config.num_attention_heads
168
+ self.attention_head_size = int(config.hidden_size /
169
+ config.num_attention_heads)
170
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
171
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
172
+ self.p_dropout = config.attention_probs_dropout_prob
173
+ self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
174
+
175
+ # Warn if defaulting to pytorch because of import issues
176
+ if flash_attn_qkvpacked_func is None:
177
+ warnings.warn(
178
+ 'Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model).'
179
+ )
180
+
181
+ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
182
+ max_seqlen_in_batch: int, indices: torch.Tensor,
183
+ attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
184
+ """Perform self-attention.
185
+
186
+ If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
187
+ implementation of self-attention.
188
+
189
+ The arguments are unpadded, and our implementations of attention require padded arguments,
190
+ so we first call `pad_input`. Once we compute attention, we re-unpad our outputs for the other layers.
191
+ The pad/unpad operations add overhead, but not sending pad tokens through ffs saves compute.
192
+ It is possible to write an unpadded implementation of attention (in Triton and PyTorch), which we will eventually do.
193
+
194
+ Args:
195
+ hidden_states: (total_nnz, dim)
196
+ cu_seqlens: (batch + 1,)
197
+ max_seqlen_in_batch: int
198
+ indices: (total_nnz,)
199
+ attn_mask: (batch, max_seqlen_in_batch)
200
+ bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
201
+
202
+ Returns:
203
+ attention: (total_nnz, dim)
204
+ """
205
+ qkv = self.Wqkv(hidden_states)
206
+ qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1,
207
+ max_seqlen_in_batch) # batch, max_seqlen_in_batch, thd
208
+ qkv = rearrange(qkv,
209
+ 'b s (t h d) -> b s t h d',
210
+ t=3,
211
+ h=self.num_attention_heads)
212
+ if self.p_dropout or flash_attn_qkvpacked_func is None:
213
+ # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
214
+ q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
215
+ k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
216
+ v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3) # b h s d
217
+ attention_scores = torch.matmul(q, k) / math.sqrt(
218
+ self.attention_head_size)
219
+ attention_scores = attention_scores + bias
220
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
221
+ attention_probs = self.dropout(attention_probs)
222
+ attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
223
+ 3) # b s h d
224
+ else:
225
+ # Triton implementation only supports 0 attention dropout
226
+ convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
227
+ if convert_dtype:
228
+ # Triton implementation only supports fp16 and bf16
229
+ orig_dtype = qkv.dtype
230
+ qkv = qkv.to(torch.float16)
231
+ bias_dtype = bias.dtype
232
+ bias = bias.to(torch.float16)
233
+ attention = flash_attn_qkvpacked_func(qkv, bias)
234
+ attention = attention.to(orig_dtype)
235
+ bias = bias.to(bias_dtype)
236
+ else:
237
+ attention = flash_attn_qkvpacked_func(qkv, bias)
238
+
239
+ # attn_mask is 1 for attend and 0 for don't
240
+ attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
241
+ return rearrange(attention, 'nnz h d -> nnz (h d)')
242
+
243
+
244
+ # Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
245
+ class BertSelfOutput(nn.Module):
246
+ """Computes the output of the attention layer.
247
+
248
+ This module is modeled after the Hugging Face BERT's
249
+ :class:`~transformers.model.bert.modeling_bert.BertSelfOutput`.
250
+ The implementation is identical. Rather than use the original module
251
+ directly, we re-implement it here so that Mosaic BERT's modules will not
252
+ be affected by any Composer surgery algorithm that modifies Hugging Face
253
+ BERT modules.
254
+ """
255
+
256
+ def __init__(self, config):
257
+ super().__init__()
258
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
259
+ self.LayerNorm = nn.LayerNorm(config.hidden_size,
260
+ eps=config.layer_norm_eps)
261
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
262
+
263
+ def forward(self, hidden_states: torch.Tensor,
264
+ input_tensor: torch.Tensor) -> torch.Tensor:
265
+ hidden_states = self.dense(hidden_states)
266
+ hidden_states = self.dropout(hidden_states)
267
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
268
+ return hidden_states
269
+
270
+
271
+ class BertUnpadAttention(nn.Module):
272
+ """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""
273
+
274
+ def __init__(self, config):
275
+ super().__init__()
276
+ self.self = BertUnpadSelfAttention(config)
277
+ self.output = BertSelfOutput(config)
278
+
279
+ def forward(
280
+ self,
281
+ input_tensor: torch.Tensor,
282
+ cu_seqlens: torch.Tensor,
283
+ max_s: int,
284
+ subset_idx: Optional[torch.Tensor] = None,
285
+ indices: Optional[torch.Tensor] = None,
286
+ attn_mask: Optional[torch.Tensor] = None,
287
+ bias: Optional[torch.Tensor] = None,
288
+ ) -> torch.Tensor:
289
+ """Forward pass for scaled self-attention without padding.
290
+
291
+ Arguments:
292
+ input_tensor: (total_nnz, dim)
293
+ cu_seqlens: (batch + 1,)
294
+ max_s: int
295
+ subset_idx: () set of indices whose values we care about at the end of the layer
296
+ (e.g., the masked tokens, if this is the final layer).
297
+ indices: None or (total_nnz,)
298
+ attn_mask: None or (batch, max_seqlen_in_batch)
299
+ bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
300
+ """
301
+ self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
302
+ attn_mask, bias)
303
+ if subset_idx is not None:
304
+ return self.output(index_first_axis(self_output, subset_idx),
305
+ index_first_axis(input_tensor, subset_idx))
306
+ else:
307
+ return self.output(self_output, input_tensor)
308
+
309
+
310
+ class BertGatedLinearUnitMLP(nn.Module):
311
+ """Applies the FFN at the end of each Mosaic BERT layer.
312
+
313
+ Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
314
+ and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality, but
315
+ introduces Gated Linear Units.
316
+
317
+ Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count consistent with that of a
318
+ standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. For example, a Mosaic BERT constructed with
319
+ `config.intermediate_size=2048` will have the same parameter footprint as its Hugging Face BERT counterpart constructed
320
+ with the `config.intermediate_size=3072`.
321
+ However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the increased
322
+ parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`.
323
+ """
324
+
325
+ def __init__(self, config):
326
+ super().__init__()
327
+ self.config = config
328
+ self.gated_layers = nn.Linear(config.hidden_size,
329
+ config.intermediate_size * 2,
330
+ bias=False)
331
+ self.act = nn.GELU(approximate='none')
332
+ self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
333
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
334
+ self.layernorm = nn.LayerNorm(config.hidden_size,
335
+ eps=config.layer_norm_eps)
336
+
337
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
338
+ """Compute new hidden states from current hidden states.
339
+
340
+ Args:
341
+ hidden_states (torch.Tensor): The (unpadded) hidden states from
342
+ the attention layer [nnz, dim].
343
+ """
344
+ residual_connection = hidden_states
345
+ # compute the activation
346
+ hidden_states = self.gated_layers(hidden_states)
347
+ gated = hidden_states[:, :self.config.intermediate_size]
348
+ non_gated = hidden_states[:, self.config.intermediate_size:]
349
+ hidden_states = self.act(gated) * non_gated
350
+ hidden_states = self.dropout(hidden_states)
351
+ # multiply by the second matrix
352
+ hidden_states = self.wo(hidden_states)
353
+ # add the residual connection and post-LN
354
+ hidden_states = self.layernorm(hidden_states + residual_connection)
355
+ return hidden_states
356
+
357
+
358
+ class BertLayer(nn.Module):
359
+ """Composes the Mosaic BERT attention and FFN blocks into a single layer."""
360
+
361
+ def __init__(self, config):
362
+ super(BertLayer, self).__init__()
363
+ self.attention = BertUnpadAttention(config)
364
+ self.mlp = BertGatedLinearUnitMLP(config)
365
+
366
+ def forward(
367
+ self,
368
+ hidden_states: torch.Tensor,
369
+ cu_seqlens: torch.Tensor,
370
+ seqlen: int,
371
+ subset_idx: Optional[torch.Tensor] = None,
372
+ indices: Optional[torch.Tensor] = None,
373
+ attn_mask: Optional[torch.Tensor] = None,
374
+ bias: Optional[torch.Tensor] = None,
375
+ ) -> torch.Tensor:
376
+ """Forward pass for a BERT layer, including both attention and MLP.
377
+
378
+ Args:
379
+ hidden_states: (total_nnz, dim)
380
+ cu_seqlens: (batch + 1,)
381
+ seqlen: int
382
+ subset_idx: () set of indices whose values we care about at the end of the layer
383
+ (e.g., the masked tokens, if this is the final layer).
384
+ indices: None or (total_nnz,)
385
+ attn_mask: None or (batch, max_seqlen_in_batch)
386
+ bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
387
+ """
388
+ attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
389
+ subset_idx, indices, attn_mask, bias)
390
+ layer_output = self.mlp(attention_output)
391
+ return layer_output
392
+
393
+
394
+ class BertEncoder(nn.Module):
395
+ """A stack of BERT layers providing the backbone of Mosaic BERT.
396
+
397
+ This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertEncoder`,
398
+ but with substantial modifications to implement unpadding and ALiBi.
399
+
400
+ Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
401
+ at padded tokens, and pre-computes attention biases to implement ALiBi.
402
+ """
403
+
404
+ def __init__(self, config):
405
+ super().__init__()
406
+ layer = BertLayer(config)
407
+ self.layer = nn.ModuleList(
408
+ [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
409
+
410
+ self.num_attention_heads = config.num_attention_heads
411
+
412
+ # The alibi mask will be dynamically expanded if it is too small for
413
+ # the input the model receives. But it generally helps to initialize it
414
+ # to a reasonably large size to help pre-allocate CUDA memory.
415
+ # The default `alibi_starting_size` is 512.
416
+ self._current_alibi_size = int(config.alibi_starting_size)
417
+ self.alibi = torch.zeros(
418
+ (1, self.num_attention_heads, self._current_alibi_size,
419
+ self._current_alibi_size))
420
+ self.rebuild_alibi_tensor(size=config.alibi_starting_size)
421
+
422
+ def rebuild_alibi_tensor(self,
423
+ size: int,
424
+ device: Optional[Union[torch.device, str]] = None):
425
+ # Alibi
426
+ # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
427
+ # In the causal case, you can exploit the fact that softmax is invariant to a uniform translation
428
+ # of the logits, which makes the math work out *after* applying causal masking. If no causal masking
429
+ # will be applied, it is necessary to construct the diagonal mask.
430
+ n_heads = self.num_attention_heads
431
+
432
+ def _get_alibi_head_slopes(n_heads: int) -> List[float]:
433
+
434
+ def get_slopes_power_of_2(n_heads: int) -> List[float]:
435
+ start = (2**(-2**-(math.log2(n_heads) - 3)))
436
+ ratio = start
437
+ return [start * ratio**i for i in range(n_heads)]
438
+
439
+ # In the paper, they only train models that have 2^a heads for some a. This function
440
+ # has some good properties that only occur when the input is a power of 2. To
441
+ # maintain that even when the number of heads is not a power of 2, we use a
442
+ # workaround.
443
+ if math.log2(n_heads).is_integer():
444
+ return get_slopes_power_of_2(n_heads)
445
+
446
+ closest_power_of_2 = 2**math.floor(math.log2(n_heads))
447
+ slopes_a = get_slopes_power_of_2(closest_power_of_2)
448
+ slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
449
+ slopes_b = slopes_b[0::2][:n_heads - closest_power_of_2]
450
+ return slopes_a + slopes_b
451
+
452
+ context_position = torch.arange(size, device=device)[:, None]
453
+ memory_position = torch.arange(size, device=device)[None, :]
454
+ relative_position = torch.abs(memory_position - context_position)
455
+ # [n_heads, max_token_length, max_token_length]
456
+ relative_position = relative_position.unsqueeze(0).expand(
457
+ n_heads, -1, -1)
458
+ slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
459
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
460
+ # [1, n_heads, max_token_length, max_token_length]
461
+ alibi = alibi.unsqueeze(0)
462
+ assert alibi.shape == torch.Size([1, n_heads, size, size])
463
+
464
+ self._current_alibi_size = size
465
+ self.alibi = alibi
466
+
467
+ def forward(
468
+ self,
469
+ hidden_states: torch.Tensor,
470
+ attention_mask: torch.Tensor,
471
+ output_all_encoded_layers: Optional[bool] = True,
472
+ subset_mask: Optional[torch.Tensor] = None,
473
+ ) -> List[torch.Tensor]:
474
+
475
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
476
+ extended_attention_mask = extended_attention_mask.to(
477
+ dtype=next(self.parameters()).dtype) # fp16 compatibility
478
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
479
+
480
+ attention_mask_bool = attention_mask.bool()
481
+ batch, seqlen = hidden_states.shape[:2]
482
+ # Unpad inputs and mask. It will remove tokens that are padded.
483
+ # Assume ntokens is total number of tokens (padded and non-padded)
484
+ # and ntokens_unpad is total number of non-padded tokens.
485
+ # Then unpadding performs the following compression of the inputs:
486
+ # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
487
+ hidden_states, indices, cu_seqlens, _ = unpad_input(
488
+ hidden_states, attention_mask_bool)
489
+
490
+ # Add alibi matrix to extended_attention_mask
491
+ if self._current_alibi_size < seqlen:
492
+ # Rebuild the alibi tensor when needed
493
+ warnings.warn(
494
+ f'Increasing alibi size from {self._current_alibi_size} to {seqlen}'
495
+ )
496
+ self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device)
497
+ elif self.alibi.device != hidden_states.device:
498
+ # Device catch-up
499
+ self.alibi = self.alibi.to(hidden_states.device)
500
+ alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
501
+ attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
502
+ alibi_attn_mask = attn_bias + alibi_bias
503
+
504
+ all_encoder_layers = []
505
+ if subset_mask is None:
506
+ for layer_module in self.layer:
507
+ hidden_states = layer_module(hidden_states,
508
+ cu_seqlens,
509
+ seqlen,
510
+ None,
511
+ indices,
512
+ attn_mask=attention_mask,
513
+ bias=alibi_attn_mask)
514
+ if output_all_encoded_layers:
515
+ all_encoder_layers.append(hidden_states)
516
+ # Pad inputs and mask. It will insert back zero-padded tokens.
517
+ # Assume ntokens is total number of tokens (padded and non-padded)
518
+ # and ntokens_unpad is total number of non-padded tokens.
519
+ # Then padding performs the following de-compression:
520
+ # hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
521
+ hidden_states = pad_input(hidden_states, indices, batch, seqlen)
522
+ else:
523
+ for i in range(len(self.layer) - 1):
524
+ layer_module = self.layer[i]
525
+ hidden_states = layer_module(hidden_states,
526
+ cu_seqlens,
527
+ seqlen,
528
+ None,
529
+ indices,
530
+ attn_mask=attention_mask,
531
+ bias=alibi_attn_mask)
532
+ if output_all_encoded_layers:
533
+ all_encoder_layers.append(hidden_states)
534
+ subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
535
+ as_tuple=False).flatten()
536
+ hidden_states = self.layer[-1](hidden_states,
537
+ cu_seqlens,
538
+ seqlen,
539
+ subset_idx=subset_idx,
540
+ indices=indices,
541
+ attn_mask=attention_mask,
542
+ bias=alibi_attn_mask)
543
+
544
+ if not output_all_encoded_layers:
545
+ all_encoder_layers.append(hidden_states)
546
+ return all_encoder_layers
547
+
548
+
549
+ class BertPooler(nn.Module):
550
+
551
+ def __init__(self, config):
552
+ super(BertPooler, self).__init__()
553
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
554
+ self.activation = nn.Tanh()
555
+
556
+ def forward(self,
557
+ hidden_states: torch.Tensor,
558
+ pool: Optional[bool] = True) -> torch.Tensor:
559
+ # We "pool" the model by simply taking the hidden state corresponding
560
+ # to the first token.
561
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
562
+ pooled_output = self.dense(first_token_tensor)
563
+ pooled_output = self.activation(pooled_output)
564
+ return pooled_output
565
+
566
+
567
+ class BertPredictionHeadTransform(nn.Module):
568
+
569
+ def __init__(self, config):
570
+ super().__init__()
571
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
572
+ if isinstance(config.hidden_act, str):
573
+ self.transform_act_fn = ACT2FN[config.hidden_act]
574
+ else:
575
+ self.transform_act_fn = config.hidden_act
576
+ self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)
577
+
578
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
579
+ hidden_states = self.dense(hidden_states)
580
+ hidden_states = self.transform_act_fn(hidden_states)
581
+ hidden_states = self.LayerNorm(hidden_states)
582
+ return hidden_states
583
+
584
+
585
+ class BertModel(BertPreTrainedModel):
586
+ """Overall BERT model.
587
+
588
+ Args:
589
+ config: a BertConfig class instance with the configuration to build a new model
590
+
591
+ Inputs:
592
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
593
+ with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
594
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
595
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
596
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
597
+ a `sentence B` token (see BERT paper for more details).
598
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
599
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
600
+ input sequence length in the current batch. It's the mask that we typically use for attention when
601
+ a batch has varying length sentences.
602
+ `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
603
+
604
+ Outputs: Tuple of (encoded_layers, pooled_output)
605
+ `encoded_layers`: controlled by `output_all_encoded_layers` argument:
606
+ - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
607
+ of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
608
+ encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
609
+ - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
610
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
611
+ `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
612
+ classifier pretrained on top of the hidden state associated to the first character of the
613
+ input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
614
+
615
+ Example usage:
616
+ ```python
617
+ # Already been converted into WordPiece token ids
618
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
619
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
620
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
621
+ config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
622
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
623
+ model = BertModel(config=config)
624
+ all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
625
+ ```
626
+ """
627
+
628
+ def __init__(self, config, add_pooling_layer=True):
629
+ super(BertModel, self).__init__(config)
630
+ self.embeddings = BertEmbeddings(config)
631
+ self.encoder = BertEncoder(config)
632
+ self.pooler = BertPooler(config) if add_pooling_layer else None
633
+ self.post_init()
634
+
635
+ def get_input_embeddings(self):
636
+ return self.embeddings.word_embeddings
637
+
638
+ def set_input_embeddings(self, value):
639
+ self.embeddings.word_embeddings = value
640
+
641
+ def forward(
642
+ self,
643
+ input_ids: torch.Tensor,
644
+ token_type_ids: Optional[torch.Tensor] = None,
645
+ attention_mask: Optional[torch.Tensor] = None,
646
+ position_ids: Optional[torch.Tensor] = None,
647
+ output_all_encoded_layers: Optional[bool] = False,
648
+ masked_tokens_mask: Optional[torch.Tensor] = None,
649
+ **kwargs
650
+ ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
651
+ if attention_mask is None:
652
+ attention_mask = torch.ones_like(input_ids)
653
+ if token_type_ids is None:
654
+ token_type_ids = torch.zeros_like(input_ids)
655
+
656
+ embedding_output = self.embeddings(input_ids, token_type_ids,
657
+ position_ids)
658
+
659
+ subset_mask = []
660
+ first_col_mask = []
661
+
662
+ if masked_tokens_mask is None:
663
+ subset_mask = None
664
+ else:
665
+ first_col_mask = torch.zeros_like(masked_tokens_mask)
666
+ first_col_mask[:, 0] = True
667
+ subset_mask = masked_tokens_mask | first_col_mask
668
+
669
+ encoder_outputs = self.encoder(
670
+ embedding_output,
671
+ attention_mask,
672
+ output_all_encoded_layers=output_all_encoded_layers,
673
+ subset_mask=subset_mask)
674
+
675
+ if masked_tokens_mask is None:
676
+ sequence_output = encoder_outputs[-1]
677
+ pooled_output = self.pooler(
678
+ sequence_output) if self.pooler is not None else None
679
+ else:
680
+ # TD [2022-03-01]: the indexing here is very tricky.
681
+ attention_mask_bool = attention_mask.bool()
682
+ subset_idx = subset_mask[attention_mask_bool] # type: ignore
683
+ sequence_output = encoder_outputs[-1][
684
+ masked_tokens_mask[attention_mask_bool][subset_idx]]
685
+ if self.pooler is not None:
686
+ pool_input = encoder_outputs[-1][
687
+ first_col_mask[attention_mask_bool][subset_idx]]
688
+ pooled_output = self.pooler(pool_input, pool=False)
689
+ else:
690
+ pooled_output = None
691
+
692
+ if not output_all_encoded_layers:
693
+ encoder_outputs = sequence_output
694
+
695
+ if self.pooler is not None:
696
+ return encoder_outputs, pooled_output
697
+
698
+ return encoder_outputs, None
699
+
700
+
701
+ ###################
702
+ # Bert Heads
703
+ ###################
704
+ class BertLMPredictionHead(nn.Module):
705
+
706
+ def __init__(self, config, bert_model_embedding_weights):
707
+ super().__init__()
708
+ self.transform = BertPredictionHeadTransform(config)
709
+ # The output weights are the same as the input embeddings, but there is
710
+ # an output-only bias for each token.
711
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
712
+ bert_model_embedding_weights.size(0))
713
+ self.decoder.weight = bert_model_embedding_weights
714
+
715
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
716
+ hidden_states = self.transform(hidden_states)
717
+ hidden_states = self.decoder(hidden_states)
718
+ return hidden_states
719
+
720
+
721
+ class BertOnlyMLMHead(nn.Module):
722
+
723
+ def __init__(self, config, bert_model_embedding_weights):
724
+ super().__init__()
725
+ self.predictions = BertLMPredictionHead(config,
726
+ bert_model_embedding_weights)
727
+
728
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
729
+ prediction_scores = self.predictions(sequence_output)
730
+ return prediction_scores
731
+
732
+
733
+ class BertOnlyNSPHead(nn.Module):
734
+
735
+ def __init__(self, config):
736
+ super().__init__()
737
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
738
+
739
+ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
740
+ seq_relationship_score = self.seq_relationship(pooled_output)
741
+ return seq_relationship_score
742
+
743
+
744
+ #####################
745
+ # Various Bert models
746
+ #####################
747
+
748
+
749
+ class BertForPreTraining(BertPreTrainedModel):
750
+ #TBD: Coming in Future Commit
751
+ pass
752
+
753
+
754
+ class BertLMHeadModel(BertPreTrainedModel):
755
+ #TBD: Coming in Future Commit
756
+ pass
757
+
758
+
759
+ class BertForMaskedLM(BertPreTrainedModel):
760
+
761
+ def __init__(self, config):
762
+ super().__init__(config)
763
+
764
+ if config.is_decoder:
765
+ warnings.warn(
766
+ 'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
767
+ 'bi-directional self-attention.')
768
+
769
+ self.bert = BertModel(config, add_pooling_layer=False)
770
+ self.cls = BertOnlyMLMHead(config,
771
+ self.bert.embeddings.word_embeddings.weight)
772
+
773
+ # Initialize weights and apply final processing
774
+ self.post_init()
775
+
776
+ @classmethod
777
+ def from_composer(cls,
778
+ pretrained_checkpoint,
779
+ state_dict=None,
780
+ cache_dir=None,
781
+ from_tf=False,
782
+ config=None,
783
+ *inputs,
784
+ **kwargs):
785
+ """Load from pre-trained."""
786
+ model = cls(config, *inputs, **kwargs)
787
+ if from_tf:
788
+ raise ValueError(
789
+ 'Mosaic BERT does not support loading TensorFlow weights.')
790
+
791
+ state_dict = torch.load(pretrained_checkpoint)
792
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
793
+ consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
794
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict,
795
+ strict=False)
796
+
797
+ if len(missing_keys) > 0:
798
+ logger.warning(
799
+ f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
800
+ )
801
+ if len(unexpected_keys) > 0:
802
+ logger.warning(
803
+ f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
804
+ )
805
+
806
+ return model
807
+
808
+ def get_output_embeddings(self):
809
+ return self.cls.predictions.decoder
810
+
811
+ def set_output_embeddings(self, new_embeddings):
812
+ self.cls.predictions.decoder = new_embeddings
813
+
814
+ def forward(
815
+ self,
816
+ input_ids: Optional[torch.Tensor] = None,
817
+ attention_mask: Optional[torch.Tensor] = None,
818
+ token_type_ids: Optional[torch.Tensor] = None,
819
+ position_ids: Optional[torch.Tensor] = None,
820
+ head_mask: Optional[torch.Tensor] = None,
821
+ inputs_embeds: Optional[torch.Tensor] = None,
822
+ encoder_hidden_states: Optional[torch.Tensor] = None,
823
+ encoder_attention_mask: Optional[torch.Tensor] = None,
824
+ labels: Optional[torch.Tensor] = None,
825
+ output_attentions: Optional[bool] = None,
826
+ output_hidden_states: Optional[bool] = None,
827
+ return_dict: Optional[bool] = None,
828
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
829
+ # labels should be a `torch.LongTensor` of shape
830
+ # `(batch_size, sequence_length)`. These are used for computing the
831
+ # masked language modeling loss.
832
+ #
833
+ # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
834
+ # `input_ids` docstring) Tokens with indices set to `-100` are ignored
835
+ # (masked), the loss is only computed for the tokens with labels in `[0,
836
+ # ..., config.vocab_size]`
837
+ #
838
+ # Prediction scores are only computed for masked tokens and the (bs,
839
+ # seqlen) dimensions are flattened
840
+ if (input_ids is not None) == (inputs_embeds is not None):
841
+ raise ValueError('Must specify either input_ids or input_embeds!')
842
+
843
+ if labels is None:
844
+ masked_tokens_mask = None
845
+ else:
846
+ masked_tokens_mask = labels > 0
847
+
848
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
849
+
850
+ outputs = self.bert(
851
+ input_ids,
852
+ attention_mask=attention_mask,
853
+ token_type_ids=token_type_ids,
854
+ position_ids=position_ids,
855
+ head_mask=head_mask,
856
+ inputs_embeds=inputs_embeds,
857
+ encoder_hidden_states=encoder_hidden_states,
858
+ encoder_attention_mask=encoder_attention_mask,
859
+ output_attentions=output_attentions,
860
+ output_hidden_states=output_hidden_states,
861
+ return_dict=return_dict,
862
+ masked_tokens_mask=masked_tokens_mask,
863
+ )
864
+
865
+ sequence_output = outputs[0]
866
+ prediction_scores = self.cls(sequence_output)
867
+
868
+ loss = None
869
+ if labels is not None:
870
+ # Compute loss
871
+ loss_fct = nn.CrossEntropyLoss()
872
+ masked_token_idx = torch.nonzero(labels.flatten() > 0,
873
+ as_tuple=False).flatten()
874
+ loss = loss_fct(prediction_scores,
875
+ labels.flatten()[masked_token_idx])
876
+
877
+ assert input_ids is not None, 'Coding error; please open an issue'
878
+ batch, seqlen = input_ids.shape[:2]
879
+ prediction_scores = rearrange(index_put_first_axis(
880
+ prediction_scores, masked_token_idx, batch * seqlen),
881
+ '(b s) d -> b s d',
882
+ b=batch)
883
+
884
+ if not return_dict:
885
+ output = (prediction_scores,) + outputs[2:]
886
+ return ((loss,) + output) if loss is not None else output
887
+
888
+ return MaskedLMOutput(
889
+ loss=loss,
890
+ logits=prediction_scores,
891
+ hidden_states=None,
892
+ attentions=None,
893
+ )
894
+
895
+ def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
896
+ attention_mask: torch.Tensor,
897
+ **model_kwargs):
898
+ input_shape = input_ids.shape
899
+ effective_batch_size = input_shape[0]
900
+
901
+ # add a dummy token
902
+ if self.config.pad_token_id is None:
903
+ raise ValueError('The PAD token should be defined for generation')
904
+
905
+ attention_mask = torch.cat([
906
+ attention_mask,
907
+ attention_mask.new_zeros((attention_mask.shape[0], 1))
908
+ ],
909
+ dim=-1)
910
+ dummy_token = torch.full((effective_batch_size, 1),
911
+ self.config.pad_token_id,
912
+ dtype=torch.long,
913
+ device=input_ids.device)
914
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
915
+
916
+ return {'input_ids': input_ids, 'attention_mask': attention_mask}
917
+
918
+
919
+ class BertForNextSentencePrediction(BertPreTrainedModel):
920
+ #TBD: Push in future commit
921
+ pass
922
+
923
+
924
+ class BertForSequenceClassification(BertPreTrainedModel):
925
+ """Bert Model transformer with a sequence classification/regression head.
926
+
927
+ This head is just a linear layer on top of the pooled output. Used for,
928
+ e.g., GLUE tasks.
929
+ """
930
+
931
+ def __init__(self, config):
932
+ super().__init__(config)
933
+ self.num_labels = config.num_labels
934
+ self.config = config
935
+
936
+ self.bert = BertModel(config)
937
+ classifier_dropout = (config.classifier_dropout
938
+ if config.classifier_dropout is not None else
939
+ config.hidden_dropout_prob)
940
+ self.dropout = nn.Dropout(classifier_dropout)
941
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
942
+
943
+ # Initialize weights and apply final processing
944
+ self.post_init()
945
+
946
+ @classmethod
947
+ def from_composer(cls,
948
+ pretrained_checkpoint,
949
+ state_dict=None,
950
+ cache_dir=None,
951
+ from_tf=False,
952
+ config=None,
953
+ *inputs,
954
+ **kwargs):
955
+ """Load from pre-trained."""
956
+ model = cls(config, *inputs, **kwargs)
957
+ if from_tf:
958
+ raise ValueError(
959
+ 'Mosaic BERT does not support loading TensorFlow weights.')
960
+
961
+ state_dict = torch.load(pretrained_checkpoint)
962
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
963
+ consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
964
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict,
965
+ strict=False)
966
+
967
+ if len(missing_keys) > 0:
968
+ logger.warning(
969
+ f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
970
+ )
971
+ if len(unexpected_keys) > 0:
972
+ logger.warning(
973
+ f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
974
+ )
975
+
976
+ return model
977
+
978
+ def forward(
979
+ self,
980
+ input_ids: Optional[torch.Tensor] = None,
981
+ attention_mask: Optional[torch.Tensor] = None,
982
+ token_type_ids: Optional[torch.Tensor] = None,
983
+ position_ids: Optional[torch.Tensor] = None,
984
+ head_mask: Optional[torch.Tensor] = None,
985
+ inputs_embeds: Optional[torch.Tensor] = None,
986
+ labels: Optional[torch.Tensor] = None,
987
+ output_attentions: Optional[bool] = None,
988
+ output_hidden_states: Optional[bool] = None,
989
+ return_dict: Optional[bool] = None,
990
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
991
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
992
+ # Labels for computing the sequence classification/regression loss.
993
+ # Indices should be in `[0, ..., config.num_labels - 1]`.
994
+ # If `config.num_labels == 1` a regression loss is computed
995
+ # (mean-square loss). If `config.num_labels > 1` a classification loss
996
+ # is computed (cross-entropy).
997
+
998
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
999
+
1000
+ outputs = self.bert(
1001
+ input_ids,
1002
+ attention_mask=attention_mask,
1003
+ token_type_ids=token_type_ids,
1004
+ position_ids=position_ids,
1005
+ head_mask=head_mask,
1006
+ inputs_embeds=inputs_embeds,
1007
+ output_attentions=output_attentions,
1008
+ output_hidden_states=output_hidden_states,
1009
+ return_dict=return_dict,
1010
+ )
1011
+
1012
+ pooled_output = outputs[1]
1013
+
1014
+ pooled_output = self.dropout(pooled_output)
1015
+ logits = self.classifier(pooled_output)
1016
+
1017
+ loss = None
1018
+ if labels is not None:
1019
+ # Compute loss
1020
+ if self.config.problem_type is None:
1021
+ if self.num_labels == 1:
1022
+ self.config.problem_type = 'regression'
1023
+ elif self.num_labels > 1 and (labels.dtype == torch.long or
1024
+ labels.dtype == torch.int):
1025
+ self.config.problem_type = 'single_label_classification'
1026
+ else:
1027
+ self.config.problem_type = 'multi_label_classification'
1028
+
1029
+ if self.config.problem_type == 'regression':
1030
+ loss_fct = nn.MSELoss()
1031
+ if self.num_labels == 1:
1032
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1033
+ else:
1034
+ loss = loss_fct(logits, labels)
1035
+ elif self.config.problem_type == 'single_label_classification':
1036
+ loss_fct = nn.CrossEntropyLoss()
1037
+ loss = loss_fct(logits.view(-1, self.num_labels),
1038
+ labels.view(-1))
1039
+ elif self.config.problem_type == 'multi_label_classification':
1040
+ loss_fct = nn.BCEWithLogitsLoss()
1041
+ loss = loss_fct(logits, labels)
1042
+
1043
+ if not return_dict:
1044
+ output = (logits,) + outputs[2:]
1045
+ return ((loss,) + output) if loss is not None else output
1046
+
1047
+ return SequenceClassifierOutput(
1048
+ loss=loss,
1049
+ logits=logits,
1050
+ hidden_states=None,
1051
+ attentions=None,
1052
+ )
1053
+
1054
+
1055
+ class BertForMultipleChoice(BertPreTrainedModel):
1056
+ #TBD: Push in future commit
1057
+ pass
1058
+
1059
+
1060
+ class BertForTokenClassification(BertPreTrainedModel):
1061
+ #TBD: Push in future commit
1062
+ pass
1063
+
1064
+
1065
+ class BertForQuestionAnswering(BertPreTrainedModel):
1066
+ """Bert Model with a span classification head.
1067
+
1068
+ This is used for extractive question-answering tasks like SQuAD (a linear
1069
+ layers on top of the hidden states' output to compute `span start logits`
1070
+ and `span end logits`).
1071
+ """
1072
+ #TBD: Push in future commit
bert_padding.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
5
+ # Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
6
+
7
+ """Helper functions for padding and unpadding batches.
8
+
9
+ These functions are used extensively throughout the Mosaic BERT implementation
10
+ in `bert_layers.py`.
11
+ """
12
+
13
+ from typing import Tuple, cast
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from einops import rearrange, repeat
18
+
19
+
20
+ class IndexFirstAxis(torch.autograd.Function):
21
+
22
+ @staticmethod
23
+ def forward(ctx, input: torch.Tensor,
24
+ indices: torch.Tensor) -> torch.Tensor:
25
+ """Get just the values of `input` which are at `indices`.
26
+
27
+ Arguments:
28
+ ctx: the autograd context object
29
+ input: (b, ...) 2+ dimensional tensor
30
+ indices: (num_idx) 1D tensor
31
+ """
32
+ ctx.save_for_backward(indices)
33
+ assert input.ndim >= 2
34
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[
35
+ 1:] # type: ignore
36
+ second_dim = other_shape.numel(
37
+ ) # product of sizes of all but first dimension
38
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
39
+ return torch.gather(
40
+ rearrange(input, 'b ... -> b (...)'), # (b, ...) -> (b, second_dim)
41
+ 0,
42
+ repeat(indices, 'z -> z d',
43
+ d=second_dim) # (indices,) -> (indices, second_dim)
44
+ ).reshape(-1, *other_shape) # (num_idx, ...)
45
+
46
+ @staticmethod
47
+ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
48
+ indices, = ctx.saved_tensors
49
+ assert grad_output.ndim >= 2
50
+ other_shape = grad_output.shape[1:]
51
+ grad_output = rearrange(grad_output, 'b ... -> b (...)')
52
+ grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
53
+ device=grad_output.device,
54
+ dtype=grad_output.dtype)
55
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
56
+ # grad_input[indices] = grad_output
57
+ grad_input.scatter_(0,
58
+ repeat(indices, 'z -> z d', d=grad_output.shape[1]),
59
+ grad_output)
60
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
61
+
62
+
63
+ index_first_axis = IndexFirstAxis.apply
64
+
65
+
66
+ class IndexPutFirstAxis(torch.autograd.Function):
67
+
68
+ @staticmethod
69
+ def forward(ctx, values: torch.Tensor, indices: torch.Tensor,
70
+ first_axis_dim) -> torch.Tensor:
71
+ ctx.save_for_backward(indices)
72
+ assert indices.ndim == 1
73
+ assert values.ndim >= 2
74
+ output = torch.zeros(first_axis_dim,
75
+ *values.shape[1:],
76
+ device=values.device,
77
+ dtype=values.dtype)
78
+ output[indices] = values
79
+ return output
80
+
81
+ @staticmethod
82
+ def backward(ctx,
83
+ grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
84
+ indices, = ctx.saved_tensors
85
+ grad_values = grad_output[indices]
86
+ return grad_values, None, None
87
+
88
+
89
+ index_put_first_axis = IndexPutFirstAxis.apply
90
+
91
+
92
+ def unpad_input(
93
+ hidden_states: torch.Tensor,
94
+ attention_mask: torch.Tensor,
95
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
96
+ """Remove padding from input sequences.
97
+
98
+ Arguments:
99
+ hidden_states: (batch, seqlen, ...)
100
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
101
+
102
+ Returns:
103
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
104
+ indices: (total_nnz)
105
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
106
+ max_seqlen_in_batch: int ()
107
+ """
108
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
109
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
110
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
111
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
112
+ (1, 0))
113
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
114
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
115
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
116
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
117
+ # so we write custom forward and backward to make it a bit faster.
118
+ hidden_states = cast(
119
+ torch.Tensor,
120
+ index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'),
121
+ indices))
122
+ return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
123
+
124
+
125
+ def unpad_input_only(
126
+ hidden_states: torch.Tensor,
127
+ attention_mask: torch.Tensor,
128
+ ) -> torch.Tensor:
129
+ """Like unpad_input, but only return the unpadded first tensor.
130
+
131
+ Save a small amount of overhead.
132
+
133
+ Arguments:
134
+ hidden_states: (batch, seqlen, ...)
135
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
136
+
137
+ Returns:
138
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
139
+ """
140
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
141
+ return index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'),
142
+ indices)
143
+
144
+
145
+ def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int,
146
+ seqlen: int) -> torch.Tensor:
147
+ """Add padding to sequences.
148
+
149
+ Arguments:
150
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
151
+ indices: (total_nnz)
152
+ batch: int batch_size
153
+ seqlen: int max sequence length
154
+
155
+ Returns:
156
+ hidden_states: (batch, seqlen, ...)
157
+ """
158
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
159
+ return rearrange(output, '(b s) ... -> b s ...', b=batch)
config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "hum-lodestone-v1",
3
+ "alibi_starting_size": 4096,
4
+ "architectures": [
5
+ "BertModel"
6
+ ],
7
+ "attention_probs_dropout_prob": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_bert.BertConfig",
10
+ "AutoModel": "bert_layers.BertModel",
11
+ "AutoModelForMaskedLM": "bert_layers.BertForMaskedLM"
12
+ },
13
+ "classifier_dropout": null,
14
+ "gradient_checkpointing": false,
15
+ "hidden_act": "gelu",
16
+ "hidden_dropout_prob": 0.1,
17
+ "hidden_size": 768,
18
+ "initializer_range": 0.02,
19
+ "intermediate_size": 3072,
20
+ "layer_norm_eps": 1e-12,
21
+ "max_position_embeddings": 512,
22
+ "model_type": "bert",
23
+ "num_attention_heads": 12,
24
+ "num_hidden_layers": 12,
25
+ "pad_token_id": 0,
26
+ "position_embedding_type": "absolute",
27
+ "tokenizer_class": "BertTokenizerFast",
28
+ "torch_dtype": "bfloat16",
29
+ "transformers_version": "4.28.1",
30
+ "type_vocab_size": 2,
31
+ "use_cache": true,
32
+ "vocab_size": 30528
33
+ }
config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "2.2.2",
4
+ "transformers": "4.28.1",
5
+ "pytorch": "2.0.1+cu117"
6
+ }
7
+ }
configuration_bert.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from transformers import BertConfig as TransformersBertConfig
5
+
6
+
7
+ class BertConfig(TransformersBertConfig):
8
+
9
+ def __init__(
10
+ self,
11
+ alibi_starting_size: int = 512,
12
+ attention_probs_dropout_prob: float = 0.0,
13
+ **kwargs,
14
+ ):
15
+ """Configuration class for MosaicBert.
16
+
17
+ Args:
18
+ alibi_starting_size (int): Use `alibi_starting_size` to determine how large of an alibi tensor to
19
+ create when initializing the model. You should be able to ignore this parameter in most cases.
20
+ Defaults to 512.
21
+ attention_probs_dropout_prob (float): By default, turn off attention dropout in Mosaic BERT
22
+ (otherwise, Flash Attention will be off by default). Defaults to 0.0.
23
+ """
24
+ super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
25
+ self.alibi_starting_size = alibi_starting_size
data_records.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"AllNLI.json.gz": 277230, "CodeSearchNet.json.gz": 1375067, "NQ-train_pairs.json.gz": 100231, "PAQ_pairs.json.gz": 64371441, "S2ORC_citation_pairs.json.gz": 52603982, "S2ORC_citations_abstracts.json.gz": 252102397, "S2ORC_title_abstract.json.gz": 41769185, "SimpleWiki.json.gz": 102225, "TriviaQA_pairs.json.gz": 73346, "WikiAnswers_pairs.json.gz": 77427422, "agnews.json.gz": 1157745, "altlex.json.gz": 112696, "amazon-qa.json.gz": 2507114, "amazon_review_2018.json.gz": 87877725, "ccnews_title_text.json.gz": 614664, "cnn_dailymail.json.gz": 311971, "coco_captions.json.gz": 828395, "eli5_question_answer.json.gz": 325475, "fever_train.json.gz": 139051, "flickr30k_captions.json.gz": 317695, "gooaq_pairs.json.gz": 3012496, "msmarco-query_passage.json.gz": 532751, "msmarco-query_passage_negative.json.gz": 9144553, "npr.json.gz": 594384, "quora_duplicates.json.gz": 103663, "quora_duplicates_triplets.json.gz": 103663, "reddit-title-body/reddit_title_text_2010.json.gz": 431782, "reddit-title-body/reddit_title_text_2011.json.gz": 1673264, "reddit-title-body/reddit_title_text_2012.json.gz": 3727526, "reddit-title-body/reddit_title_text_2013.json.gz": 5713956, "reddit-title-body/reddit_title_text_2014.json.gz": 8538976, "reddit-title-body/reddit_title_text_2015.json.gz": 11064453, "reddit-title-body/reddit_title_text_2016.json.gz": 12224789, "reddit-title-body/reddit_title_text_2017.json.gz": 13558139, "reddit-title-body/reddit_title_text_2018.json.gz": 15552110, "reddit-title-body/reddit_title_text_2019.json.gz": 19224970, "reddit-title-body/reddit_title_text_2020.json.gz": 23030988, "reddit-title-body/reddit_title_text_2021.json.gz": 12704958, "reddit_2015.json.gz": 135108166, "reddit_2016.json.gz": 159164386, "reddit_2017.json.gz": 191485219, "reddit_2018.json.gz": 240726659, "searchQA_question_top5_snippets_merged.json.gz": 582261, "searchQA_question_topSnippet.json.gz": 117384, "sentence-compression.json.gz": 180000, "specter_train_triples.json.gz": 684100, "squad_pairs.json.gz": 87599, "stackexchange_duplicate_questions_body_body.json.gz": 250459, "stackexchange_duplicate_questions_title-body_title-body.json.gz": 250518, "stackexchange_duplicate_questions_title_title.json.gz": 304524, "stackexchange_title_best_voted_answer_jsonl/3dprinting.stackexchange.com.json.gz": 3488, "stackexchange_title_best_voted_answer_jsonl/academia.stackexchange.com.json.gz": 32137, "stackexchange_title_best_voted_answer_jsonl/ai.stackexchange.com.json.gz": 5763, "stackexchange_title_best_voted_answer_jsonl/android.stackexchange.com.json.gz": 38077, "stackexchange_title_best_voted_answer_jsonl/anime.stackexchange.com.json.gz": 10131, "stackexchange_title_best_voted_answer_jsonl/apple.stackexchange.com.json.gz": 92487, "stackexchange_title_best_voted_answer_jsonl/arduino.stackexchange.com.json.gz": 16281, "stackexchange_title_best_voted_answer_jsonl/askubuntu.com.json.gz": 267135, "stackexchange_title_best_voted_answer_jsonl/astronomy.stackexchange.com.json.gz": 9086, "stackexchange_title_best_voted_answer_jsonl/aviation.stackexchange.com.json.gz": 18755, "stackexchange_title_best_voted_answer_jsonl/avp.stackexchange.com.json.gz": 6450, "stackexchange_title_best_voted_answer_jsonl/beer.stackexchange.com.json.gz": 1012, "stackexchange_title_best_voted_answer_jsonl/bicycles.stackexchange.com.json.gz": 15708, "stackexchange_title_best_voted_answer_jsonl/bioinformatics.stackexchange.com.json.gz": 3135, "stackexchange_title_best_voted_answer_jsonl/biology.stackexchange.com.json.gz": 19277, "stackexchange_title_best_voted_answer_jsonl/bitcoin.stackexchange.com.json.gz": 22474, "stackexchange_title_best_voted_answer_jsonl/blender.stackexchange.com.json.gz": 54153, "stackexchange_title_best_voted_answer_jsonl/boardgames.stackexchange.com.json.gz": 11805, "stackexchange_title_best_voted_answer_jsonl/bricks.stackexchange.com.json.gz": 3530, "stackexchange_title_best_voted_answer_jsonl/buddhism.stackexchange.com.json.gz": 6787, "stackexchange_title_best_voted_answer_jsonl/cardano.stackexchange.com.json.gz": 248, "stackexchange_title_best_voted_answer_jsonl/chemistry.stackexchange.com.json.gz": 27061, "stackexchange_title_best_voted_answer_jsonl/chess.stackexchange.com.json.gz": 6392, "stackexchange_title_best_voted_answer_jsonl/chinese.stackexchange.com.json.gz": 8646, "stackexchange_title_best_voted_answer_jsonl/christianity.stackexchange.com.json.gz": 11498, "stackexchange_title_best_voted_answer_jsonl/civicrm.stackexchange.com.json.gz": 10648, "stackexchange_title_best_voted_answer_jsonl/codegolf.stackexchange.com.json.gz": 8211, "stackexchange_title_best_voted_answer_jsonl/codereview.stackexchange.com.json.gz": 41748, "stackexchange_title_best_voted_answer_jsonl/coffee.stackexchange.com.json.gz": 1188, "stackexchange_title_best_voted_answer_jsonl/cogsci.stackexchange.com.json.gz": 5101, "stackexchange_title_best_voted_answer_jsonl/computergraphics.stackexchange.com.json.gz": 2306, "stackexchange_title_best_voted_answer_jsonl/conlang.stackexchange.com.json.gz": 334, "stackexchange_title_best_voted_answer_jsonl/cooking.stackexchange.com.json.gz": 22641, "stackexchange_title_best_voted_answer_jsonl/craftcms.stackexchange.com.json.gz": 11236, "stackexchange_title_best_voted_answer_jsonl/crafts.stackexchange.com.json.gz": 1659, "stackexchange_title_best_voted_answer_jsonl/crypto.stackexchange.com.json.gz": 19404, "stackexchange_title_best_voted_answer_jsonl/cs.stackexchange.com.json.gz": 30010, "stackexchange_title_best_voted_answer_jsonl/cseducators.stackexchange.com.json.gz": 902, "stackexchange_title_best_voted_answer_jsonl/cstheory.stackexchange.com.json.gz": 7742, "stackexchange_title_best_voted_answer_jsonl/datascience.stackexchange.com.json.gz": 20503, "stackexchange_title_best_voted_answer_jsonl/dba.stackexchange.com.json.gz": 71449, "stackexchange_title_best_voted_answer_jsonl/devops.stackexchange.com.json.gz": 3462, "stackexchange_title_best_voted_answer_jsonl/diy.stackexchange.com.json.gz": 52896, "stackexchange_title_best_voted_answer_jsonl/drones.stackexchange.com.json.gz": 496, "stackexchange_title_best_voted_answer_jsonl/drupal.stackexchange.com.json.gz": 67817, "stackexchange_title_best_voted_answer_jsonl/dsp.stackexchange.com.json.gz": 17430, "stackexchange_title_best_voted_answer_jsonl/earthscience.stackexchange.com.json.gz": 4396, "stackexchange_title_best_voted_answer_jsonl/ebooks.stackexchange.com.json.gz": 1107, "stackexchange_title_best_voted_answer_jsonl/economics.stackexchange.com.json.gz": 8844, "stackexchange_title_best_voted_answer_jsonl/electronics.stackexchange.com.json.gz": 129494, "stackexchange_title_best_voted_answer_jsonl/ell.stackexchange.com.json.gz": 77892, "stackexchange_title_best_voted_answer_jsonl/emacs.stackexchange.com.json.gz": 16830, "stackexchange_title_best_voted_answer_jsonl/engineering.stackexchange.com.json.gz": 8649, "stackexchange_title_best_voted_answer_jsonl/english.stackexchange.com.json.gz": 100640, "stackexchange_title_best_voted_answer_jsonl/eosio.stackexchange.com.json.gz": 1940, "stackexchange_title_best_voted_answer_jsonl/esperanto.stackexchange.com.json.gz": 1466, "stackexchange_title_best_voted_answer_jsonl/ethereum.stackexchange.com.json.gz": 26124, "stackexchange_title_best_voted_answer_jsonl/expatriates.stackexchange.com.json.gz": 4913, "stackexchange_title_best_voted_answer_jsonl/expressionengine.stackexchange.com.json.gz": 10742, "stackexchange_title_best_voted_answer_jsonl/fitness.stackexchange.com.json.gz": 8297, "stackexchange_title_best_voted_answer_jsonl/freelancing.stackexchange.com.json.gz": 1663, "stackexchange_title_best_voted_answer_jsonl/french.stackexchange.com.json.gz": 10578, "stackexchange_title_best_voted_answer_jsonl/gamedev.stackexchange.com.json.gz": 40154, "stackexchange_title_best_voted_answer_jsonl/gaming.stackexchange.com.json.gz": 82887, "stackexchange_title_best_voted_answer_jsonl/gardening.stackexchange.com.json.gz": 13246, "stackexchange_title_best_voted_answer_jsonl/genealogy.stackexchange.com.json.gz": 2895, "stackexchange_title_best_voted_answer_jsonl/german.stackexchange.com.json.gz": 13733, "stackexchange_title_best_voted_answer_jsonl/gis.stackexchange.com.json.gz": 100254, "stackexchange_title_best_voted_answer_jsonl/graphicdesign.stackexchange.com.json.gz": 28083, "stackexchange_title_best_voted_answer_jsonl/ham.stackexchange.com.json.gz": 3501, "stackexchange_title_best_voted_answer_jsonl/hardwarerecs.stackexchange.com.json.gz": 2050, "stackexchange_title_best_voted_answer_jsonl/health.stackexchange.com.json.gz": 4494, "stackexchange_title_best_voted_answer_jsonl/hermeneutics.stackexchange.com.json.gz": 9516, "stackexchange_title_best_voted_answer_jsonl/hinduism.stackexchange.com.json.gz": 8999, "stackexchange_title_best_voted_answer_jsonl/history.stackexchange.com.json.gz": 10766, "stackexchange_title_best_voted_answer_jsonl/homebrew.stackexchange.com.json.gz": 5608, "stackexchange_title_best_voted_answer_jsonl/hsm.stackexchange.com.json.gz": 2517, "stackexchange_title_best_voted_answer_jsonl/interpersonal.stackexchange.com.json.gz": 3398, "stackexchange_title_best_voted_answer_jsonl/iot.stackexchange.com.json.gz": 1359, "stackexchange_title_best_voted_answer_jsonl/iota.stackexchange.com.json.gz": 775, "stackexchange_title_best_voted_answer_jsonl/islam.stackexchange.com.json.gz": 10052, "stackexchange_title_best_voted_answer_jsonl/italian.stackexchange.com.json.gz": 3101, "stackexchange_title_best_voted_answer_jsonl/ja.stackoverflow.com.json.gz": 17376, "stackexchange_title_best_voted_answer_jsonl/japanese.stackexchange.com.json.gz": 20948, "stackexchange_title_best_voted_answer_jsonl/joomla.stackexchange.com.json.gz": 5887, "stackexchange_title_best_voted_answer_jsonl/judaism.stackexchange.com.json.gz": 26085, "stackexchange_title_best_voted_answer_jsonl/korean.stackexchange.com.json.gz": 1406, "stackexchange_title_best_voted_answer_jsonl/languagelearning.stackexchange.com.json.gz": 948, "stackexchange_title_best_voted_answer_jsonl/latin.stackexchange.com.json.gz": 3969, "stackexchange_title_best_voted_answer_jsonl/law.stackexchange.com.json.gz": 16133, "stackexchange_title_best_voted_answer_jsonl/lifehacks.stackexchange.com.json.gz": 2576, "stackexchange_title_best_voted_answer_jsonl/linguistics.stackexchange.com.json.gz": 6843, "stackexchange_title_best_voted_answer_jsonl/literature.stackexchange.com.json.gz": 3539, "stackexchange_title_best_voted_answer_jsonl/magento.stackexchange.com.json.gz": 79241, "stackexchange_title_best_voted_answer_jsonl/martialarts.stackexchange.com.json.gz": 1737, "stackexchange_title_best_voted_answer_jsonl/materials.stackexchange.com.json.gz": 1101, "stackexchange_title_best_voted_answer_jsonl/matheducators.stackexchange.com.json.gz": 2706, "stackexchange_title_best_voted_answer_jsonl/mathematica.stackexchange.com.json.gz": 59895, "stackexchange_title_best_voted_answer_jsonl/mathoverflow.net.json.gz": 85289, "stackexchange_title_best_voted_answer_jsonl/mechanics.stackexchange.com.json.gz": 18613, "stackexchange_title_best_voted_answer_jsonl/meta.askubuntu.com.json.gz": 4268, "stackexchange_title_best_voted_answer_jsonl/meta.mathoverflow.net.json.gz": 1000, "stackexchange_title_best_voted_answer_jsonl/meta.serverfault.com.json.gz": 1726, "stackexchange_title_best_voted_answer_jsonl/meta.stackexchange.com.json.gz": 60744, "stackexchange_title_best_voted_answer_jsonl/meta.stackoverflow.com.json.gz": 24044, "stackexchange_title_best_voted_answer_jsonl/meta.superuser.com.json.gz": 3629, "stackexchange_title_best_voted_answer_jsonl/moderators.stackexchange.com.json.gz": 504, "stackexchange_title_best_voted_answer_jsonl/money.stackexchange.com.json.gz": 29404, "stackexchange_title_best_voted_answer_jsonl/movies.stackexchange.com.json.gz": 18243, "stackexchange_title_best_voted_answer_jsonl/music.stackexchange.com.json.gz": 19936, "stackexchange_title_best_voted_answer_jsonl/musicfans.stackexchange.com.json.gz": 2431, "stackexchange_title_best_voted_answer_jsonl/mythology.stackexchange.com.json.gz": 1595, "stackexchange_title_best_voted_answer_jsonl/networkengineering.stackexchange.com.json.gz": 12590, "stackexchange_title_best_voted_answer_jsonl/opendata.stackexchange.com.json.gz": 3842, "stackexchange_title_best_voted_answer_jsonl/opensource.stackexchange.com.json.gz": 3221, "stackexchange_title_best_voted_answer_jsonl/or.stackexchange.com.json.gz": 1490, "stackexchange_title_best_voted_answer_jsonl/outdoors.stackexchange.com.json.gz": 5278, "stackexchange_title_best_voted_answer_jsonl/parenting.stackexchange.com.json.gz": 5998, "stackexchange_title_best_voted_answer_jsonl/patents.stackexchange.com.json.gz": 3573, "stackexchange_title_best_voted_answer_jsonl/pets.stackexchange.com.json.gz": 6156, "stackexchange_title_best_voted_answer_jsonl/philosophy.stackexchange.com.json.gz": 13114, "stackexchange_title_best_voted_answer_jsonl/photo.stackexchange.com.json.gz": 23204, "stackexchange_title_best_voted_answer_jsonl/physics.stackexchange.com.json.gz": 141230, "stackexchange_title_best_voted_answer_jsonl/pm.stackexchange.com.json.gz": 5435, "stackexchange_title_best_voted_answer_jsonl/poker.stackexchange.com.json.gz": 1665, "stackexchange_title_best_voted_answer_jsonl/politics.stackexchange.com.json.gz": 11047, "stackexchange_title_best_voted_answer_jsonl/portuguese.stackexchange.com.json.gz": 1964, "stackexchange_title_best_voted_answer_jsonl/pt.stackoverflow.com.json.gz": 103277, "stackexchange_title_best_voted_answer_jsonl/puzzling.stackexchange.com.json.gz": 17448, "stackexchange_title_best_voted_answer_jsonl/quant.stackexchange.com.json.gz": 12933, "stackexchange_title_best_voted_answer_jsonl/quantumcomputing.stackexchange.com.json.gz": 4320, "stackexchange_title_best_voted_answer_jsonl/raspberrypi.stackexchange.com.json.gz": 24143, "stackexchange_title_best_voted_answer_jsonl/retrocomputing.stackexchange.com.json.gz": 3907, "stackexchange_title_best_voted_answer_jsonl/reverseengineering.stackexchange.com.json.gz": 5817, "stackexchange_title_best_voted_answer_jsonl/robotics.stackexchange.com.json.gz": 4648, "stackexchange_title_best_voted_answer_jsonl/rpg.stackexchange.com.json.gz": 40435, "stackexchange_title_best_voted_answer_jsonl/ru.stackoverflow.com.json.gz": 253289, "stackexchange_title_best_voted_answer_jsonl/rus.stackexchange.com.json.gz": 16528, "stackexchange_title_best_voted_answer_jsonl/russian.stackexchange.com.json.gz": 3937, "stackexchange_title_best_voted_answer_jsonl/salesforce.stackexchange.com.json.gz": 87272, "stackexchange_title_best_voted_answer_jsonl/scicomp.stackexchange.com.json.gz": 7036, "stackexchange_title_best_voted_answer_jsonl/scifi.stackexchange.com.json.gz": 54805, "stackexchange_title_best_voted_answer_jsonl/serverfault.com.json.gz": 238507, "stackexchange_title_best_voted_answer_jsonl/sharepoint.stackexchange.com.json.gz": 80420, "stackexchange_title_best_voted_answer_jsonl/sitecore.stackexchange.com.json.gz": 7838, "stackexchange_title_best_voted_answer_jsonl/skeptics.stackexchange.com.json.gz": 8145, "stackexchange_title_best_voted_answer_jsonl/softwareengineering.stackexchange.com.json.gz": 51326, "stackexchange_title_best_voted_answer_jsonl/softwarerecs.stackexchange.com.json.gz": 11761, "stackexchange_title_best_voted_answer_jsonl/sound.stackexchange.com.json.gz": 8303, "stackexchange_title_best_voted_answer_jsonl/space.stackexchange.com.json.gz": 12893, "stackexchange_title_best_voted_answer_jsonl/spanish.stackexchange.com.json.gz": 7675, "stackexchange_title_best_voted_answer_jsonl/sports.stackexchange.com.json.gz": 4707, "stackexchange_title_best_voted_answer_jsonl/sqa.stackexchange.com.json.gz": 9256, "stackexchange_title_best_voted_answer_jsonl/stackapps.com.json.gz": 1518, "stackexchange_title_best_voted_answer_jsonl/stats.stackexchange.com.json.gz": 115679, "stackexchange_title_best_voted_answer_jsonl/stellar.stackexchange.com.json.gz": 1078, "stackexchange_title_best_voted_answer_jsonl/superuser.com.json.gz": 352610, "stackexchange_title_best_voted_answer_jsonl/sustainability.stackexchange.com.json.gz": 1674, "stackexchange_title_best_voted_answer_jsonl/tex.stackexchange.com.json.gz": 171628, "stackexchange_title_best_voted_answer_jsonl/tezos.stackexchange.com.json.gz": 1169, "stackexchange_title_best_voted_answer_jsonl/tor.stackexchange.com.json.gz": 4167, "stackexchange_title_best_voted_answer_jsonl/travel.stackexchange.com.json.gz": 36533, "stackexchange_title_best_voted_answer_jsonl/tridion.stackexchange.com.json.gz": 5907, "stackexchange_title_best_voted_answer_jsonl/ukrainian.stackexchange.com.json.gz": 1767, "stackexchange_title_best_voted_answer_jsonl/unix.stackexchange.com.json.gz": 155414, "stackexchange_title_best_voted_answer_jsonl/ux.stackexchange.com.json.gz": 28901, "stackexchange_title_best_voted_answer_jsonl/vegetarianism.stackexchange.com.json.gz": 585, "stackexchange_title_best_voted_answer_jsonl/vi.stackexchange.com.json.gz": 9000, "stackexchange_title_best_voted_answer_jsonl/webapps.stackexchange.com.json.gz": 24867, "stackexchange_title_best_voted_answer_jsonl/webmasters.stackexchange.com.json.gz": 30370, "stackexchange_title_best_voted_answer_jsonl/windowsphone.stackexchange.com.json.gz": 2807, "stackexchange_title_best_voted_answer_jsonl/woodworking.stackexchange.com.json.gz": 2955, "stackexchange_title_best_voted_answer_jsonl/wordpress.stackexchange.com.json.gz": 83621, "stackexchange_title_best_voted_answer_jsonl/workplace.stackexchange.com.json.gz": 24012, "stackexchange_title_best_voted_answer_jsonl/worldbuilding.stackexchange.com.json.gz": 26210, "stackexchange_title_best_voted_answer_jsonl/writers.stackexchange.com.json.gz": 9867, "stackexchange_title_body_jsonl/academia.stackexchange.com.json.gz": 34331, "stackexchange_title_body_jsonl/android.stackexchange.com.json.gz": 51608, "stackexchange_title_body_jsonl/anime.stackexchange.com.json.gz": 11444, "stackexchange_title_body_jsonl/apple.stackexchange.com.json.gz": 110622, "stackexchange_title_body_jsonl/arduino.stackexchange.com.json.gz": 19553, "stackexchange_title_body_jsonl/askubuntu.com.json.gz": 347925, "stackexchange_title_body_jsonl/astronomy.stackexchange.com.json.gz": 10462, "stackexchange_title_body_jsonl/aviation.stackexchange.com.json.gz": 20139, "stackexchange_title_body_jsonl/bicycles.stackexchange.com.json.gz": 16353, "stackexchange_title_body_jsonl/biology.stackexchange.com.json.gz": 24447, "stackexchange_title_body_jsonl/bitcoin.stackexchange.com.json.gz": 25374, "stackexchange_title_body_jsonl/blender.stackexchange.com.json.gz": 80766, "stackexchange_title_body_jsonl/boardgames.stackexchange.com.json.gz": 12149, "stackexchange_title_body_jsonl/chemistry.stackexchange.com.json.gz": 34506, "stackexchange_title_body_jsonl/christianity.stackexchange.com.json.gz": 12108, "stackexchange_title_body_jsonl/civicrm.stackexchange.com.json.gz": 12543, "stackexchange_title_body_jsonl/codereview.stackexchange.com.json.gz": 45765, "stackexchange_title_body_jsonl/cooking.stackexchange.com.json.gz": 23705, "stackexchange_title_body_jsonl/craftcms.stackexchange.com.json.gz": 12574, "stackexchange_title_body_jsonl/crypto.stackexchange.com.json.gz": 23231, "stackexchange_title_body_jsonl/cs.stackexchange.com.json.gz": 38314, "stackexchange_title_body_jsonl/cstheory.stackexchange.com.json.gz": 10642, "stackexchange_title_body_jsonl/datascience.stackexchange.com.json.gz": 27397, "stackexchange_title_body_jsonl/dba.stackexchange.com.json.gz": 81871, "stackexchange_title_body_jsonl/diy.stackexchange.com.json.gz": 60083, "stackexchange_title_body_jsonl/drupal.stackexchange.com.json.gz": 79717, "stackexchange_title_body_jsonl/dsp.stackexchange.com.json.gz": 21252, "stackexchange_title_body_jsonl/economics.stackexchange.com.json.gz": 11115, "stackexchange_title_body_jsonl/electronics.stackexchange.com.json.gz": 143582, "stackexchange_title_body_jsonl/ell.stackexchange.com.json.gz": 83271, "stackexchange_title_body_jsonl/emacs.stackexchange.com.json.gz": 21055, "stackexchange_title_body_jsonl/engineering.stackexchange.com.json.gz": 10753, "stackexchange_title_body_jsonl/english.stackexchange.com.json.gz": 109522, "stackexchange_title_body_jsonl/ethereum.stackexchange.com.json.gz": 32760, "stackexchange_title_body_jsonl/expressionengine.stackexchange.com.json.gz": 11866, "stackexchange_title_body_jsonl/french.stackexchange.com.json.gz": 10794, "stackexchange_title_body_jsonl/gamedev.stackexchange.com.json.gz": 46485, "stackexchange_title_body_jsonl/gaming.stackexchange.com.json.gz": 88912, "stackexchange_title_body_jsonl/gardening.stackexchange.com.json.gz": 15136, "stackexchange_title_body_jsonl/german.stackexchange.com.json.gz": 13950, "stackexchange_title_body_jsonl/gis.stackexchange.com.json.gz": 131000, "stackexchange_title_body_jsonl/graphicdesign.stackexchange.com.json.gz": 30233, "stackexchange_title_body_jsonl/hinduism.stackexchange.com.json.gz": 13450, "stackexchange_title_body_jsonl/history.stackexchange.com.json.gz": 12021, "stackexchange_title_body_jsonl/islam.stackexchange.com.json.gz": 11853, "stackexchange_title_body_jsonl/japanese.stackexchange.com.json.gz": 22056, "stackexchange_title_body_jsonl/judaism.stackexchange.com.json.gz": 32028, "stackexchange_title_body_jsonl/law.stackexchange.com.json.gz": 17941, "stackexchange_title_body_jsonl/magento.stackexchange.com.json.gz": 99991, "stackexchange_title_body_jsonl/math.stackexchange.com.json.gz": 1338443, "stackexchange_title_body_jsonl/mathematica.stackexchange.com.json.gz": 73131, "stackexchange_title_body_jsonl/mathoverflow.net.json.gz": 120851, "stackexchange_title_body_jsonl/mechanics.stackexchange.com.json.gz": 22868, "stackexchange_title_body_jsonl/meta.stackexchange.com.json.gz": 83510, "stackexchange_title_body_jsonl/meta.stackoverflow.com.json.gz": 36456, "stackexchange_title_body_jsonl/money.stackexchange.com.json.gz": 32021, "stackexchange_title_body_jsonl/movies.stackexchange.com.json.gz": 20181, "stackexchange_title_body_jsonl/music.stackexchange.com.json.gz": 20636, "stackexchange_title_body_jsonl/networkengineering.stackexchange.com.json.gz": 13454, "stackexchange_title_body_jsonl/philosophy.stackexchange.com.json.gz": 14829, "stackexchange_title_body_jsonl/photo.stackexchange.com.json.gz": 23753, "stackexchange_title_body_jsonl/physics.stackexchange.com.json.gz": 173307, "stackexchange_title_body_jsonl/politics.stackexchange.com.json.gz": 11894, "stackexchange_title_body_jsonl/puzzling.stackexchange.com.json.gz": 17851, "stackexchange_title_body_jsonl/quant.stackexchange.com.json.gz": 17261, "stackexchange_title_body_jsonl/raspberrypi.stackexchange.com.json.gz": 30625, "stackexchange_title_body_jsonl/rpg.stackexchange.com.json.gz": 42303, "stackexchange_title_body_jsonl/rus.stackexchange.com.json.gz": 16871, "stackexchange_title_body_jsonl/salesforce.stackexchange.com.json.gz": 105260, "stackexchange_title_body_jsonl/scifi.stackexchange.com.json.gz": 61528, "stackexchange_title_body_jsonl/sharepoint.stackexchange.com.json.gz": 94011, "stackexchange_title_body_jsonl/skeptics.stackexchange.com.json.gz": 10009, "stackexchange_title_body_jsonl/small_stackexchanges.json.gz": 448146, "stackexchange_title_body_jsonl/softwareengineering.stackexchange.com.json.gz": 53942, "stackexchange_title_body_jsonl/softwarerecs.stackexchange.com.json.gz": 20142, "stackexchange_title_body_jsonl/space.stackexchange.com.json.gz": 15142, "stackexchange_title_body_jsonl/stackoverflow.com-Posts.json.gz": 18562443, "stackexchange_title_body_jsonl/stats.stackexchange.com.json.gz": 173466, "stackexchange_title_body_jsonl/superuser.com.json.gz": 435463, "stackexchange_title_body_jsonl/tex.stackexchange.com.json.gz": 202954, "stackexchange_title_body_jsonl/travel.stackexchange.com.json.gz": 41227, "stackexchange_title_body_jsonl/unix.stackexchange.com.json.gz": 185997, "stackexchange_title_body_jsonl/ux.stackexchange.com.json.gz": 29403, "stackexchange_title_body_jsonl/vi.stackexchange.com.json.gz": 10551, "stackexchange_title_body_jsonl/webapps.stackexchange.com.json.gz": 29697, "stackexchange_title_body_jsonl/webmasters.stackexchange.com.json.gz": 34559, "stackexchange_title_body_jsonl/wordpress.stackexchange.com.json.gz": 100474, "stackexchange_title_body_jsonl/workplace.stackexchange.com.json.gz": 24189, "stackexchange_title_body_jsonl/worldbuilding.stackexchange.com.json.gz": 26763, "stackexchange_title_body_jsonl/writers.stackexchange.com.json.gz": 10157, "stackexchange_title_body_small.json.gz": 364000, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/3dprinting.stackexchange.com.json.gz": 109, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/academia.stackexchange.com.json.gz": 2465, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ai.stackexchange.com.json.gz": 130, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/android.stackexchange.com.json.gz": 2830, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/anime.stackexchange.com.json.gz": 802, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/apple.stackexchange.com.json.gz": 6696, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/arduino.stackexchange.com.json.gz": 595, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/askubuntu.com.json.gz": 9975, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/astronomy.stackexchange.com.json.gz": 371, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/aviation.stackexchange.com.json.gz": 903, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/avp.stackexchange.com.json.gz": 152, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/beer.stackexchange.com.json.gz": 57, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/bicycles.stackexchange.com.json.gz": 984, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/bioinformatics.stackexchange.com.json.gz": 39, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/biology.stackexchange.com.json.gz": 832, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/bitcoin.stackexchange.com.json.gz": 1068, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/blender.stackexchange.com.json.gz": 1312, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/boardgames.stackexchange.com.json.gz": 691, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/bricks.stackexchange.com.json.gz": 79, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/buddhism.stackexchange.com.json.gz": 770, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/cardano.stackexchange.com.json.gz": 7, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/chemistry.stackexchange.com.json.gz": 1523, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/chess.stackexchange.com.json.gz": 402, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/chinese.stackexchange.com.json.gz": 611, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/christianity.stackexchange.com.json.gz": 1502, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/civicrm.stackexchange.com.json.gz": 85, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/codegolf.stackexchange.com.json.gz": 333, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/codereview.stackexchange.com.json.gz": 666, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/coffee.stackexchange.com.json.gz": 47, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/cogsci.stackexchange.com.json.gz": 221, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/computergraphics.stackexchange.com.json.gz": 30, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/conlang.stackexchange.com.json.gz": 8, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/cooking.stackexchange.com.json.gz": 2064, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/craftcms.stackexchange.com.json.gz": 26, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/crafts.stackexchange.com.json.gz": 72, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/crypto.stackexchange.com.json.gz": 595, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/cs.stackexchange.com.json.gz": 936, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/cseducators.stackexchange.com.json.gz": 67, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/cstheory.stackexchange.com.json.gz": 314, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/datascience.stackexchange.com.json.gz": 325, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/dba.stackexchange.com.json.gz": 2502, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/devops.stackexchange.com.json.gz": 53, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/diy.stackexchange.com.json.gz": 2037, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/drones.stackexchange.com.json.gz": 6, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/drupal.stackexchange.com.json.gz": 1714, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/dsp.stackexchange.com.json.gz": 387, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/earthscience.stackexchange.com.json.gz": 229, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ebooks.stackexchange.com.json.gz": 54, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/economics.stackexchange.com.json.gz": 441, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/electronics.stackexchange.com.json.gz": 4014, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/elementaryos.stackexchange.com.json.gz": 224, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ell.stackexchange.com.json.gz": 4438, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/emacs.stackexchange.com.json.gz": 188, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/engineering.stackexchange.com.json.gz": 227, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/english.stackexchange.com.json.gz": 13003, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/eosio.stackexchange.com.json.gz": 44, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/esperanto.stackexchange.com.json.gz": 56, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ethereum.stackexchange.com.json.gz": 479, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/expatriates.stackexchange.com.json.gz": 132, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/expressionengine.stackexchange.com.json.gz": 91, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/fitness.stackexchange.com.json.gz": 567, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/freelancing.stackexchange.com.json.gz": 70, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/french.stackexchange.com.json.gz": 632, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/gamedev.stackexchange.com.json.gz": 1598, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/gaming.stackexchange.com.json.gz": 7321, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/gardening.stackexchange.com.json.gz": 210, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/genealogy.stackexchange.com.json.gz": 86, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/german.stackexchange.com.json.gz": 1047, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/gis.stackexchange.com.json.gz": 1843, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/graphicdesign.stackexchange.com.json.gz": 1565, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ham.stackexchange.com.json.gz": 158, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/hardwarerecs.stackexchange.com.json.gz": 58, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/health.stackexchange.com.json.gz": 299, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/hermeneutics.stackexchange.com.json.gz": 1719, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/hinduism.stackexchange.com.json.gz": 343, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/history.stackexchange.com.json.gz": 1099, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/homebrew.stackexchange.com.json.gz": 176, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/hsm.stackexchange.com.json.gz": 70, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/interpersonal.stackexchange.com.json.gz": 469, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/iot.stackexchange.com.json.gz": 10, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/iota.stackexchange.com.json.gz": 31, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/islam.stackexchange.com.json.gz": 2037, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/italian.stackexchange.com.json.gz": 181, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ja.stackoverflow.com.json.gz": 328, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/japanese.stackexchange.com.json.gz": 1124, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/joomla.stackexchange.com.json.gz": 124, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/judaism.stackexchange.com.json.gz": 2216, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/korean.stackexchange.com.json.gz": 28, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/languagelearning.stackexchange.com.json.gz": 42, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/latin.stackexchange.com.json.gz": 55, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/law.stackexchange.com.json.gz": 1297, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/lifehacks.stackexchange.com.json.gz": 316, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/linguistics.stackexchange.com.json.gz": 442, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/literature.stackexchange.com.json.gz": 191, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/magento.stackexchange.com.json.gz": 1849, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/martialarts.stackexchange.com.json.gz": 254, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/materials.stackexchange.com.json.gz": 1, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/matheducators.stackexchange.com.json.gz": 177, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/mathematica.stackexchange.com.json.gz": 262, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/mathoverflow.net.json.gz": 1109, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/mechanics.stackexchange.com.json.gz": 842, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/meta.askubuntu.com.json.gz": 252, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/meta.mathoverflow.net.json.gz": 61, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/meta.serverfault.com.json.gz": 114, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/meta.stackexchange.com.json.gz": 2517, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/meta.stackoverflow.com.json.gz": 2678, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/meta.superuser.com.json.gz": 145, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/moderators.stackexchange.com.json.gz": 23, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/monero.stackexchange.com.json.gz": 26, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/money.stackexchange.com.json.gz": 1905, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/movies.stackexchange.com.json.gz": 1577, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/music.stackexchange.com.json.gz": 1228, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/musicfans.stackexchange.com.json.gz": 78, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/mythology.stackexchange.com.json.gz": 103, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/networkengineering.stackexchange.com.json.gz": 476, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/opendata.stackexchange.com.json.gz": 45, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/opensource.stackexchange.com.json.gz": 123, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/or.stackexchange.com.json.gz": 13, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/outdoors.stackexchange.com.json.gz": 221, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/parenting.stackexchange.com.json.gz": 624, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/patents.stackexchange.com.json.gz": 137, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/pets.stackexchange.com.json.gz": 322, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/philosophy.stackexchange.com.json.gz": 1184, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/photo.stackexchange.com.json.gz": 1432, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/physics.stackexchange.com.json.gz": 8362, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/pm.stackexchange.com.json.gz": 241, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/poker.stackexchange.com.json.gz": 115, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/politics.stackexchange.com.json.gz": 1468, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/portuguese.stackexchange.com.json.gz": 144, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/pt.stackoverflow.com.json.gz": 3718, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/puzzling.stackexchange.com.json.gz": 784, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/quant.stackexchange.com.json.gz": 340, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/quantumcomputing.stackexchange.com.json.gz": 46, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/raspberrypi.stackexchange.com.json.gz": 1011, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/retrocomputing.stackexchange.com.json.gz": 135, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/reverseengineering.stackexchange.com.json.gz": 97, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/robotics.stackexchange.com.json.gz": 110, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/rpg.stackexchange.com.json.gz": 4212, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ru.stackoverflow.com.json.gz": 6305, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/rus.stackexchange.com.json.gz": 514, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/russian.stackexchange.com.json.gz": 353, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/salesforce.stackexchange.com.json.gz": 1781, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/scicomp.stackexchange.com.json.gz": 127, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/scifi.stackexchange.com.json.gz": 5176, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/security.stackexchange.com.json.gz": 3069, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/serverfault.com.json.gz": 7969, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/sharepoint.stackexchange.com.json.gz": 1691, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/sitecore.stackexchange.com.json.gz": 122, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/skeptics.stackexchange.com.json.gz": 670, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/softwareengineering.stackexchange.com.json.gz": 4238, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/softwarerecs.stackexchange.com.json.gz": 348, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/sound.stackexchange.com.json.gz": 365, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/space.stackexchange.com.json.gz": 405, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/spanish.stackexchange.com.json.gz": 366, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/sports.stackexchange.com.json.gz": 455, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/sqa.stackexchange.com.json.gz": 353, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/stackapps.com.json.gz": 15, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/stats.stackexchange.com.json.gz": 2238, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/stellar.stackexchange.com.json.gz": 3, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/superuser.com.json.gz": 17425, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/sustainability.stackexchange.com.json.gz": 152, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/tex.stackexchange.com.json.gz": 1095, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/tezos.stackexchange.com.json.gz": 11, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/tor.stackexchange.com.json.gz": 137, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/travel.stackexchange.com.json.gz": 1317, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/tridion.stackexchange.com.json.gz": 68, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ukrainian.stackexchange.com.json.gz": 87, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/unix.stackexchange.com.json.gz": 6173, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ux.stackexchange.com.json.gz": 1107, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/vegetarianism.stackexchange.com.json.gz": 35, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/vi.stackexchange.com.json.gz": 95, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/webapps.stackexchange.com.json.gz": 1906, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/webmasters.stackexchange.com.json.gz": 854, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/windowsphone.stackexchange.com.json.gz": 153, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/woodworking.stackexchange.com.json.gz": 93, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/wordpress.stackexchange.com.json.gz": 3046, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/workplace.stackexchange.com.json.gz": 4317, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/worldbuilding.stackexchange.com.json.gz": 2087, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/writers.stackexchange.com.json.gz": 407, "stackexchange_titlebody_best_voted_answer_jsonl/3dprinting.stackexchange.com.json.gz": 3488, "stackexchange_titlebody_best_voted_answer_jsonl/academia.stackexchange.com.json.gz": 32137, "stackexchange_titlebody_best_voted_answer_jsonl/ai.stackexchange.com.json.gz": 5763, "stackexchange_titlebody_best_voted_answer_jsonl/android.stackexchange.com.json.gz": 38077, "stackexchange_titlebody_best_voted_answer_jsonl/anime.stackexchange.com.json.gz": 10131, "stackexchange_titlebody_best_voted_answer_jsonl/apple.stackexchange.com.json.gz": 92487, "stackexchange_titlebody_best_voted_answer_jsonl/arduino.stackexchange.com.json.gz": 16281, "stackexchange_titlebody_best_voted_answer_jsonl/askubuntu.com.json.gz": 267135, "stackexchange_titlebody_best_voted_answer_jsonl/astronomy.stackexchange.com.json.gz": 9086, "stackexchange_titlebody_best_voted_answer_jsonl/aviation.stackexchange.com.json.gz": 18755, "stackexchange_titlebody_best_voted_answer_jsonl/avp.stackexchange.com.json.gz": 6450, "stackexchange_titlebody_best_voted_answer_jsonl/beer.stackexchange.com.json.gz": 1012, "stackexchange_titlebody_best_voted_answer_jsonl/bicycles.stackexchange.com.json.gz": 15708, "stackexchange_titlebody_best_voted_answer_jsonl/bioinformatics.stackexchange.com.json.gz": 3135, "stackexchange_titlebody_best_voted_answer_jsonl/biology.stackexchange.com.json.gz": 19277, "stackexchange_titlebody_best_voted_answer_jsonl/bitcoin.stackexchange.com.json.gz": 22474, "stackexchange_titlebody_best_voted_answer_jsonl/blender.stackexchange.com.json.gz": 54153, "stackexchange_titlebody_best_voted_answer_jsonl/boardgames.stackexchange.com.json.gz": 11805, "stackexchange_titlebody_best_voted_answer_jsonl/bricks.stackexchange.com.json.gz": 3530, "stackexchange_titlebody_best_voted_answer_jsonl/buddhism.stackexchange.com.json.gz": 6787, "stackexchange_titlebody_best_voted_answer_jsonl/cardano.stackexchange.com.json.gz": 248, "stackexchange_titlebody_best_voted_answer_jsonl/chemistry.stackexchange.com.json.gz": 27061, "stackexchange_titlebody_best_voted_answer_jsonl/chess.stackexchange.com.json.gz": 6392, "stackexchange_titlebody_best_voted_answer_jsonl/chinese.stackexchange.com.json.gz": 8646, "stackexchange_titlebody_best_voted_answer_jsonl/christianity.stackexchange.com.json.gz": 11498, "stackexchange_titlebody_best_voted_answer_jsonl/civicrm.stackexchange.com.json.gz": 10648, "stackexchange_titlebody_best_voted_answer_jsonl/codegolf.stackexchange.com.json.gz": 8211, "stackexchange_titlebody_best_voted_answer_jsonl/codereview.stackexchange.com.json.gz": 41748, "stackexchange_titlebody_best_voted_answer_jsonl/coffee.stackexchange.com.json.gz": 1188, "stackexchange_titlebody_best_voted_answer_jsonl/cogsci.stackexchange.com.json.gz": 5101, "stackexchange_titlebody_best_voted_answer_jsonl/computergraphics.stackexchange.com.json.gz": 2306, "stackexchange_titlebody_best_voted_answer_jsonl/conlang.stackexchange.com.json.gz": 334, "stackexchange_titlebody_best_voted_answer_jsonl/cooking.stackexchange.com.json.gz": 22641, "stackexchange_titlebody_best_voted_answer_jsonl/craftcms.stackexchange.com.json.gz": 11236, "stackexchange_titlebody_best_voted_answer_jsonl/crafts.stackexchange.com.json.gz": 1659, "stackexchange_titlebody_best_voted_answer_jsonl/crypto.stackexchange.com.json.gz": 19404, "stackexchange_titlebody_best_voted_answer_jsonl/cs.stackexchange.com.json.gz": 30010, "stackexchange_titlebody_best_voted_answer_jsonl/cseducators.stackexchange.com.json.gz": 902, "stackexchange_titlebody_best_voted_answer_jsonl/cstheory.stackexchange.com.json.gz": 7742, "stackexchange_titlebody_best_voted_answer_jsonl/datascience.stackexchange.com.json.gz": 20503, "stackexchange_titlebody_best_voted_answer_jsonl/dba.stackexchange.com.json.gz": 71449, "stackexchange_titlebody_best_voted_answer_jsonl/devops.stackexchange.com.json.gz": 3462, "stackexchange_titlebody_best_voted_answer_jsonl/diy.stackexchange.com.json.gz": 52896, "stackexchange_titlebody_best_voted_answer_jsonl/drones.stackexchange.com.json.gz": 496, "stackexchange_titlebody_best_voted_answer_jsonl/drupal.stackexchange.com.json.gz": 67817, "stackexchange_titlebody_best_voted_answer_jsonl/dsp.stackexchange.com.json.gz": 17430, "stackexchange_titlebody_best_voted_answer_jsonl/earthscience.stackexchange.com.json.gz": 4396, "stackexchange_titlebody_best_voted_answer_jsonl/ebooks.stackexchange.com.json.gz": 1107, "stackexchange_titlebody_best_voted_answer_jsonl/economics.stackexchange.com.json.gz": 8844, "stackexchange_titlebody_best_voted_answer_jsonl/electronics.stackexchange.com.json.gz": 129494, "stackexchange_titlebody_best_voted_answer_jsonl/elementaryos.stackexchange.com.json.gz": 5917, "stackexchange_titlebody_best_voted_answer_jsonl/ell.stackexchange.com.json.gz": 77892, "stackexchange_titlebody_best_voted_answer_jsonl/emacs.stackexchange.com.json.gz": 16830, "stackexchange_titlebody_best_voted_answer_jsonl/engineering.stackexchange.com.json.gz": 8649, "stackexchange_titlebody_best_voted_answer_jsonl/english.stackexchange.com.json.gz": 100640, "stackexchange_titlebody_best_voted_answer_jsonl/eosio.stackexchange.com.json.gz": 1940, "stackexchange_titlebody_best_voted_answer_jsonl/esperanto.stackexchange.com.json.gz": 1466, "stackexchange_titlebody_best_voted_answer_jsonl/ethereum.stackexchange.com.json.gz": 26124, "stackexchange_titlebody_best_voted_answer_jsonl/expatriates.stackexchange.com.json.gz": 4913, "stackexchange_titlebody_best_voted_answer_jsonl/expressionengine.stackexchange.com.json.gz": 10742, "stackexchange_titlebody_best_voted_answer_jsonl/fitness.stackexchange.com.json.gz": 8297, "stackexchange_titlebody_best_voted_answer_jsonl/freelancing.stackexchange.com.json.gz": 1663, "stackexchange_titlebody_best_voted_answer_jsonl/french.stackexchange.com.json.gz": 10578, "stackexchange_titlebody_best_voted_answer_jsonl/gamedev.stackexchange.com.json.gz": 40154, "stackexchange_titlebody_best_voted_answer_jsonl/gaming.stackexchange.com.json.gz": 82887, "stackexchange_titlebody_best_voted_answer_jsonl/gardening.stackexchange.com.json.gz": 13246, "stackexchange_titlebody_best_voted_answer_jsonl/genealogy.stackexchange.com.json.gz": 2895, "stackexchange_titlebody_best_voted_answer_jsonl/german.stackexchange.com.json.gz": 13733, "stackexchange_titlebody_best_voted_answer_jsonl/gis.stackexchange.com.json.gz": 100254, "stackexchange_titlebody_best_voted_answer_jsonl/graphicdesign.stackexchange.com.json.gz": 28083, "stackexchange_titlebody_best_voted_answer_jsonl/ham.stackexchange.com.json.gz": 3501, "stackexchange_titlebody_best_voted_answer_jsonl/hardwarerecs.stackexchange.com.json.gz": 2050, "stackexchange_titlebody_best_voted_answer_jsonl/health.stackexchange.com.json.gz": 4494, "stackexchange_titlebody_best_voted_answer_jsonl/hermeneutics.stackexchange.com.json.gz": 9516, "stackexchange_titlebody_best_voted_answer_jsonl/hinduism.stackexchange.com.json.gz": 8999, "stackexchange_titlebody_best_voted_answer_jsonl/history.stackexchange.com.json.gz": 10766, "stackexchange_titlebody_best_voted_answer_jsonl/homebrew.stackexchange.com.json.gz": 5608, "stackexchange_titlebody_best_voted_answer_jsonl/hsm.stackexchange.com.json.gz": 2517, "stackexchange_titlebody_best_voted_answer_jsonl/interpersonal.stackexchange.com.json.gz": 3398, "stackexchange_titlebody_best_voted_answer_jsonl/iot.stackexchange.com.json.gz": 1359, "stackexchange_titlebody_best_voted_answer_jsonl/iota.stackexchange.com.json.gz": 775, "stackexchange_titlebody_best_voted_answer_jsonl/islam.stackexchange.com.json.gz": 10052, "stackexchange_titlebody_best_voted_answer_jsonl/italian.stackexchange.com.json.gz": 3101, "stackexchange_titlebody_best_voted_answer_jsonl/ja.stackoverflow.com.json.gz": 17376, "stackexchange_titlebody_best_voted_answer_jsonl/japanese.stackexchange.com.json.gz": 20948, "stackexchange_titlebody_best_voted_answer_jsonl/joomla.stackexchange.com.json.gz": 5887, "stackexchange_titlebody_best_voted_answer_jsonl/judaism.stackexchange.com.json.gz": 26085, "stackexchange_titlebody_best_voted_answer_jsonl/korean.stackexchange.com.json.gz": 1406, "stackexchange_titlebody_best_voted_answer_jsonl/languagelearning.stackexchange.com.json.gz": 948, "stackexchange_titlebody_best_voted_answer_jsonl/latin.stackexchange.com.json.gz": 3969, "stackexchange_titlebody_best_voted_answer_jsonl/law.stackexchange.com.json.gz": 16133, "stackexchange_titlebody_best_voted_answer_jsonl/lifehacks.stackexchange.com.json.gz": 2576, "stackexchange_titlebody_best_voted_answer_jsonl/linguistics.stackexchange.com.json.gz": 6843, "stackexchange_titlebody_best_voted_answer_jsonl/literature.stackexchange.com.json.gz": 3539, "stackexchange_titlebody_best_voted_answer_jsonl/magento.stackexchange.com.json.gz": 79241, "stackexchange_titlebody_best_voted_answer_jsonl/martialarts.stackexchange.com.json.gz": 1737, "stackexchange_titlebody_best_voted_answer_jsonl/materials.stackexchange.com.json.gz": 1101, "stackexchange_titlebody_best_voted_answer_jsonl/matheducators.stackexchange.com.json.gz": 2706, "stackexchange_titlebody_best_voted_answer_jsonl/mathematica.stackexchange.com.json.gz": 59895, "stackexchange_titlebody_best_voted_answer_jsonl/mathoverflow.net.json.gz": 85289, "stackexchange_titlebody_best_voted_answer_jsonl/mechanics.stackexchange.com.json.gz": 18613, "stackexchange_titlebody_best_voted_answer_jsonl/meta.askubuntu.com.json.gz": 4268, "stackexchange_titlebody_best_voted_answer_jsonl/meta.mathoverflow.net.json.gz": 1000, "stackexchange_titlebody_best_voted_answer_jsonl/meta.serverfault.com.json.gz": 1726, "stackexchange_titlebody_best_voted_answer_jsonl/meta.stackexchange.com.json.gz": 60744, "stackexchange_titlebody_best_voted_answer_jsonl/meta.stackoverflow.com.json.gz": 24044, "stackexchange_titlebody_best_voted_answer_jsonl/meta.superuser.com.json.gz": 3629, "stackexchange_titlebody_best_voted_answer_jsonl/moderators.stackexchange.com.json.gz": 504, "stackexchange_titlebody_best_voted_answer_jsonl/money.stackexchange.com.json.gz": 29404, "stackexchange_titlebody_best_voted_answer_jsonl/movies.stackexchange.com.json.gz": 18243, "stackexchange_titlebody_best_voted_answer_jsonl/music.stackexchange.com.json.gz": 19936, "stackexchange_titlebody_best_voted_answer_jsonl/musicfans.stackexchange.com.json.gz": 2431, "stackexchange_titlebody_best_voted_answer_jsonl/mythology.stackexchange.com.json.gz": 1595, "stackexchange_titlebody_best_voted_answer_jsonl/networkengineering.stackexchange.com.json.gz": 12590, "stackexchange_titlebody_best_voted_answer_jsonl/opendata.stackexchange.com.json.gz": 3842, "stackexchange_titlebody_best_voted_answer_jsonl/opensource.stackexchange.com.json.gz": 3221, "stackexchange_titlebody_best_voted_answer_jsonl/or.stackexchange.com.json.gz": 1490, "stackexchange_titlebody_best_voted_answer_jsonl/outdoors.stackexchange.com.json.gz": 5278, "stackexchange_titlebody_best_voted_answer_jsonl/parenting.stackexchange.com.json.gz": 5998, "stackexchange_titlebody_best_voted_answer_jsonl/patents.stackexchange.com.json.gz": 3573, "stackexchange_titlebody_best_voted_answer_jsonl/pets.stackexchange.com.json.gz": 6156, "stackexchange_titlebody_best_voted_answer_jsonl/philosophy.stackexchange.com.json.gz": 13114, "stackexchange_titlebody_best_voted_answer_jsonl/photo.stackexchange.com.json.gz": 23204, "stackexchange_titlebody_best_voted_answer_jsonl/physics.stackexchange.com.json.gz": 141230, "stackexchange_titlebody_best_voted_answer_jsonl/pm.stackexchange.com.json.gz": 5435, "stackexchange_titlebody_best_voted_answer_jsonl/poker.stackexchange.com.json.gz": 1665, "stackexchange_titlebody_best_voted_answer_jsonl/politics.stackexchange.com.json.gz": 11047, "stackexchange_titlebody_best_voted_answer_jsonl/portuguese.stackexchange.com.json.gz": 1964, "stackexchange_titlebody_best_voted_answer_jsonl/pt.stackoverflow.com.json.gz": 103277, "stackexchange_titlebody_best_voted_answer_jsonl/puzzling.stackexchange.com.json.gz": 17448, "stackexchange_titlebody_best_voted_answer_jsonl/quant.stackexchange.com.json.gz": 12933, "stackexchange_titlebody_best_voted_answer_jsonl/quantumcomputing.stackexchange.com.json.gz": 4320, "stackexchange_titlebody_best_voted_answer_jsonl/raspberrypi.stackexchange.com.json.gz": 24143, "stackexchange_titlebody_best_voted_answer_jsonl/retrocomputing.stackexchange.com.json.gz": 3907, "stackexchange_titlebody_best_voted_answer_jsonl/reverseengineering.stackexchange.com.json.gz": 5817, "stackexchange_titlebody_best_voted_answer_jsonl/robotics.stackexchange.com.json.gz": 4648, "stackexchange_titlebody_best_voted_answer_jsonl/rpg.stackexchange.com.json.gz": 40435, "stackexchange_titlebody_best_voted_answer_jsonl/ru.stackoverflow.com.json.gz": 253289, "stackexchange_titlebody_best_voted_answer_jsonl/rus.stackexchange.com.json.gz": 16528, "stackexchange_titlebody_best_voted_answer_jsonl/russian.stackexchange.com.json.gz": 3937, "stackexchange_titlebody_best_voted_answer_jsonl/salesforce.stackexchange.com.json.gz": 87272, "stackexchange_titlebody_best_voted_answer_jsonl/scicomp.stackexchange.com.json.gz": 7036, "stackexchange_titlebody_best_voted_answer_jsonl/scifi.stackexchange.com.json.gz": 54805, "stackexchange_titlebody_best_voted_answer_jsonl/sharepoint.stackexchange.com.json.gz": 80420, "stackexchange_titlebody_best_voted_answer_jsonl/sitecore.stackexchange.com.json.gz": 7838, "stackexchange_titlebody_best_voted_answer_jsonl/skeptics.stackexchange.com.json.gz": 8145, "stackexchange_titlebody_best_voted_answer_jsonl/softwareengineering.stackexchange.com.json.gz": 51326, "stackexchange_titlebody_best_voted_answer_jsonl/softwarerecs.stackexchange.com.json.gz": 11761, "stackexchange_titlebody_best_voted_answer_jsonl/sound.stackexchange.com.json.gz": 8303, "stackexchange_titlebody_best_voted_answer_jsonl/space.stackexchange.com.json.gz": 12893, "stackexchange_titlebody_best_voted_answer_jsonl/spanish.stackexchange.com.json.gz": 7675, "stackexchange_titlebody_best_voted_answer_jsonl/sports.stackexchange.com.json.gz": 4707, "stackexchange_titlebody_best_voted_answer_jsonl/sqa.stackexchange.com.json.gz": 9256, "stackexchange_titlebody_best_voted_answer_jsonl/stackapps.com.json.gz": 1518, "stackexchange_titlebody_best_voted_answer_jsonl/stats.stackexchange.com.json.gz": 115679, "stackexchange_titlebody_best_voted_answer_jsonl/stellar.stackexchange.com.json.gz": 1078, "stackexchange_titlebody_best_voted_answer_jsonl/superuser.com.json.gz": 352610, "stackexchange_titlebody_best_voted_answer_jsonl/sustainability.stackexchange.com.json.gz": 1674, "stackexchange_titlebody_best_voted_answer_jsonl/tex.stackexchange.com.json.gz": 171628, "stackexchange_titlebody_best_voted_answer_jsonl/tezos.stackexchange.com.json.gz": 1169, "stackexchange_titlebody_best_voted_answer_jsonl/tor.stackexchange.com.json.gz": 4167, "stackexchange_titlebody_best_voted_answer_jsonl/travel.stackexchange.com.json.gz": 36533, "stackexchange_titlebody_best_voted_answer_jsonl/tridion.stackexchange.com.json.gz": 5907, "stackexchange_titlebody_best_voted_answer_jsonl/ukrainian.stackexchange.com.json.gz": 1767, "stackexchange_titlebody_best_voted_answer_jsonl/unix.stackexchange.com.json.gz": 155414, "stackexchange_titlebody_best_voted_answer_jsonl/ux.stackexchange.com.json.gz": 28901, "stackexchange_titlebody_best_voted_answer_jsonl/vegetarianism.stackexchange.com.json.gz": 585, "stackexchange_titlebody_best_voted_answer_jsonl/vi.stackexchange.com.json.gz": 9000, "stackexchange_titlebody_best_voted_answer_jsonl/webapps.stackexchange.com.json.gz": 24867, "stackexchange_titlebody_best_voted_answer_jsonl/webmasters.stackexchange.com.json.gz": 30370, "stackexchange_titlebody_best_voted_answer_jsonl/windowsphone.stackexchange.com.json.gz": 2807, "stackexchange_titlebody_best_voted_answer_jsonl/woodworking.stackexchange.com.json.gz": 2955, "stackexchange_titlebody_best_voted_answer_jsonl/wordpress.stackexchange.com.json.gz": 83621, "stackexchange_titlebody_best_voted_answer_jsonl/workplace.stackexchange.com.json.gz": 24012, "stackexchange_titlebody_best_voted_answer_jsonl/worldbuilding.stackexchange.com.json.gz": 26210, "stackexchange_titlebody_best_voted_answer_jsonl/writers.stackexchange.com.json.gz": 9867, "wikihow.json.gz": 128542, "xsum.json.gz": 226711, "yahoo_answers_question_answer.json.gz": 681164, "yahoo_answers_title_answer.json.gz": 1198260, "yahoo_answers_title_question.json.gz": 659896}
flash_attn_triton.py ADDED
@@ -0,0 +1,1112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Triton implementation of Flash Attention.
5
+
6
+ # Copyright (c) 2022, Tri Dao.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ *Experimental* implementation of FlashAttention in Triton.
21
+ We use the FlashAttention implementation from Phil Tillet a starting point.
22
+ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
23
+
24
+ Changes:
25
+ - Implement both causal and non-causal attention.
26
+ - Implement both self-attention and cross-attention.
27
+ - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
28
+ - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
29
+ - Support attention bias.
30
+ - Speed up the forward pass a bit, and only store the LSE instead of m and l.
31
+ - Make the backward for d=128 much faster by reducing register spilling.
32
+ - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
33
+ small batch size * nheads.
34
+
35
+ Caution:
36
+ - If you plan to use headdim other than 64 and 128, you should test for race conditions
37
+ (due to the Triton compiler), as done in tests/test_flash_attn.py
38
+ "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
39
+ for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
40
+ that there are none left for other head dimensions.
41
+ Differences between this Triton version and the CUDA version:
42
+ - Triton version doesn't support dropout.
43
+ - Triton forward is generally faster than CUDA forward.
44
+ - Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64.
45
+ It is slightly slower when headdim=128 and batch * nheads is large.
46
+ - Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
47
+ """
48
+
49
+ import math
50
+
51
+ import torch
52
+ import triton # type: ignore (reportMissingImports)
53
+ import triton.language as tl # type: ignore (reportMissingImports)
54
+ from einops import repeat
55
+
56
+
57
+ @triton.autotune(
58
+ configs=[
59
+ triton.Config({
60
+ 'BLOCK_M': 128,
61
+ 'BLOCK_N': 128
62
+ },
63
+ num_warps=8,
64
+ num_stages=1),
65
+ # This config has a race condition when EVEN_M == False, disabling it for now.
66
+ # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
67
+ ],
68
+ key=[
69
+ 'CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL',
70
+ 'BLOCK_HEADDIM'
71
+ ])
72
+ @triton.heuristics({
73
+ 'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0,
74
+ 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0,
75
+ 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM'],
76
+ })
77
+ @triton.jit
78
+ def _fwd_kernel(
79
+ Q,
80
+ K,
81
+ V,
82
+ Bias,
83
+ Out,
84
+ Lse,
85
+ TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
86
+ softmax_scale,
87
+ stride_qb,
88
+ stride_qh,
89
+ stride_qm,
90
+ stride_kb,
91
+ stride_kh,
92
+ stride_kn,
93
+ stride_vb,
94
+ stride_vh,
95
+ stride_vn,
96
+ stride_bb,
97
+ stride_bh,
98
+ stride_bm,
99
+ stride_ob,
100
+ stride_oh,
101
+ stride_om,
102
+ nheads,
103
+ seqlen_q,
104
+ seqlen_k,
105
+ seqlen_q_rounded,
106
+ headdim,
107
+ CACHE_KEY_SEQLEN_Q,
108
+ CACHE_KEY_SEQLEN_K,
109
+ BIAS_TYPE: tl.constexpr,
110
+ IS_CAUSAL: tl.constexpr,
111
+ BLOCK_HEADDIM: tl.constexpr,
112
+ EVEN_M: tl.constexpr,
113
+ EVEN_N: tl.constexpr,
114
+ EVEN_HEADDIM: tl.constexpr,
115
+ BLOCK_M: tl.constexpr,
116
+ BLOCK_N: tl.constexpr,
117
+ ):
118
+ start_m = tl.program_id(0)
119
+ off_hb = tl.program_id(1)
120
+ off_b = off_hb // nheads
121
+ off_h = off_hb % nheads
122
+ # off_b = tl.program_id(1)
123
+ # off_h = tl.program_id(2)
124
+ # off_hb = off_b * nheads + off_h
125
+ # initialize offsets
126
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
127
+ offs_n = tl.arange(0, BLOCK_N)
128
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
129
+ # Initialize pointers to Q, K, V
130
+ # Adding parenthesis around indexing might use int32 math instead of int64 math?
131
+ # https://github.com/openai/triton/issues/741
132
+ # I'm seeing a tiny bit of difference (5-7us)
133
+ q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (
134
+ offs_m[:, None] * stride_qm + offs_d[None, :])
135
+ k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (
136
+ offs_n[:, None] * stride_kn + offs_d[None, :])
137
+ v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (
138
+ offs_n[:, None] * stride_vn + offs_d[None, :])
139
+ if BIAS_TYPE == 'vector':
140
+ b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
141
+ elif BIAS_TYPE == 'matrix':
142
+ b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (
143
+ offs_m[:, None] * stride_bm + offs_n[None, :])
144
+ else:
145
+ raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
146
+ # initialize pointer to m and l
147
+ t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
148
+ lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
149
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
150
+ acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
151
+ # load q: it will stay in SRAM throughout
152
+ # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
153
+ # tl.load(q_ptrs), we get the wrong output!
154
+ if EVEN_M & EVEN_N:
155
+ if EVEN_HEADDIM:
156
+ q = tl.load(q_ptrs)
157
+ else:
158
+ q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
159
+ else:
160
+ if EVEN_HEADDIM:
161
+ q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
162
+ else:
163
+ q = tl.load(q_ptrs,
164
+ mask=(offs_m[:, None] < seqlen_q) &
165
+ (offs_d[None, :] < headdim),
166
+ other=0.0)
167
+ # loop over k, v and update accumulator
168
+ end_n = seqlen_k if not IS_CAUSAL else tl.minimum(
169
+ (start_m + 1) * BLOCK_M, seqlen_k)
170
+ for start_n in range(0, end_n, BLOCK_N):
171
+ start_n = tl.multiple_of(start_n, BLOCK_N)
172
+ # -- compute qk ----
173
+ if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
174
+ if EVEN_HEADDIM:
175
+ k = tl.load(k_ptrs + start_n * stride_kn)
176
+ else:
177
+ k = tl.load(k_ptrs + start_n * stride_kn,
178
+ mask=offs_d[None, :] < headdim,
179
+ other=0.0)
180
+ else:
181
+ if EVEN_HEADDIM:
182
+ k = tl.load(k_ptrs + start_n * stride_kn,
183
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
184
+ other=0.0)
185
+ else:
186
+ k = tl.load(k_ptrs + start_n * stride_kn,
187
+ mask=((start_n + offs_n)[:, None] < seqlen_k) &
188
+ (offs_d[None, :] < headdim),
189
+ other=0.0)
190
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
191
+ qk += tl.dot(q, k, trans_b=True)
192
+ # Trying to combine the two masks seem to make the result wrong
193
+ if not EVEN_N: # Need to mask out otherwise the softmax is wrong
194
+ qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
195
+ float('-inf'))
196
+ if IS_CAUSAL:
197
+ qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0,
198
+ float('-inf'))
199
+ if BIAS_TYPE != 'none':
200
+ if BIAS_TYPE == 'vector':
201
+ if EVEN_N:
202
+ bias = tl.load(b_ptrs + start_n).to(tl.float32)
203
+ else:
204
+ bias = tl.load(b_ptrs + start_n,
205
+ mask=(start_n + offs_n) < seqlen_k,
206
+ other=0.0).to(tl.float32)
207
+ bias = bias[None, :]
208
+ elif BIAS_TYPE == 'matrix':
209
+ if EVEN_M & EVEN_N:
210
+ bias = tl.load(b_ptrs + start_n).to(tl.float32)
211
+ else:
212
+ bias = tl.load(b_ptrs + start_n,
213
+ mask=(offs_m[:, None] < seqlen_q) &
214
+ ((start_n + offs_n)[None, :] < seqlen_k),
215
+ other=0.0).to(tl.float32)
216
+ else:
217
+ raise ValueError(
218
+ "BIAS_TYPE must be one of {'vector', 'matrix'}")
219
+ # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
220
+ # can then fuse the mult and add into an fma instruction. But if we have bias we need to
221
+ # to multiply with softmax_scale here.
222
+ qk = qk * softmax_scale + bias
223
+ m_ij = tl.maximum(tl.max(qk, 1), lse_i)
224
+ p = tl.exp(qk - m_ij[:, None])
225
+ else:
226
+ m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
227
+ p = tl.exp(qk * softmax_scale - m_ij[:, None])
228
+ l_ij = tl.sum(p, 1)
229
+
230
+ # scale acc_o
231
+ acc_o_scale = tl.exp(m_i - m_ij)
232
+
233
+ # # -- update output accumulator --
234
+ # BUG: have to store and immediately load
235
+ tl.store(t_ptrs, acc_o_scale)
236
+ acc_o_scale = tl.load(t_ptrs)
237
+ acc_o = acc_o * acc_o_scale[:, None]
238
+ # update acc_o
239
+ if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
240
+ if EVEN_HEADDIM:
241
+ v = tl.load(v_ptrs + start_n * stride_vn)
242
+ else:
243
+ v = tl.load(v_ptrs + start_n * stride_vn,
244
+ mask=offs_d[None, :] < headdim,
245
+ other=0.0)
246
+ else:
247
+ if EVEN_HEADDIM:
248
+ v = tl.load(v_ptrs + start_n * stride_vn,
249
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
250
+ other=0.0)
251
+ else:
252
+ v = tl.load(v_ptrs + start_n * stride_vn,
253
+ mask=((start_n + offs_n)[:, None] < seqlen_k) &
254
+ (offs_d[None, :] < headdim),
255
+ other=0.0)
256
+ p = p.to(v.dtype)
257
+ acc_o += tl.dot(p, v)
258
+
259
+ # -- update statistics
260
+ m_i = m_ij
261
+ l_i_new = tl.exp(lse_i - m_ij) + l_ij
262
+ lse_i = m_ij + tl.log(l_i_new)
263
+
264
+ o_scale = tl.exp(m_i - lse_i)
265
+ # BUG: have to store and immediately load
266
+ tl.store(t_ptrs, o_scale)
267
+ o_scale = tl.load(t_ptrs)
268
+ acc_o = acc_o * o_scale[:, None]
269
+ # rematerialize offsets to save registers
270
+ start_m = tl.program_id(0)
271
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
272
+ # write back l and m
273
+ lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
274
+ tl.store(lse_ptrs, lse_i)
275
+ # initialize pointers to output
276
+ offs_n = tl.arange(0, BLOCK_HEADDIM)
277
+ out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (
278
+ offs_m[:, None] * stride_om + offs_n[None, :])
279
+ if EVEN_M:
280
+ if EVEN_HEADDIM:
281
+ tl.store(out_ptrs, acc_o)
282
+ else:
283
+ tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
284
+ else:
285
+ if EVEN_HEADDIM:
286
+ tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
287
+ else:
288
+ tl.store(out_ptrs,
289
+ acc_o,
290
+ mask=(offs_m[:, None] < seqlen_q) &
291
+ (offs_d[None, :] < headdim))
292
+
293
+
294
+ @triton.jit
295
+ def _bwd_preprocess_do_o_dot(
296
+ Out,
297
+ DO,
298
+ Delta,
299
+ stride_ob,
300
+ stride_oh,
301
+ stride_om,
302
+ stride_dob,
303
+ stride_doh,
304
+ stride_dom,
305
+ nheads,
306
+ seqlen_q,
307
+ seqlen_q_rounded,
308
+ headdim,
309
+ BLOCK_M: tl.constexpr,
310
+ BLOCK_HEADDIM: tl.constexpr,
311
+ ):
312
+ start_m = tl.program_id(0)
313
+ off_hb = tl.program_id(1)
314
+ off_b = off_hb // nheads
315
+ off_h = off_hb % nheads
316
+ # initialize offsets
317
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
318
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
319
+ # load
320
+ o = tl.load(Out + off_b * stride_ob + off_h * stride_oh +
321
+ offs_m[:, None] * stride_om + offs_d[None, :],
322
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
323
+ other=0.0).to(tl.float32)
324
+ do = tl.load(DO + off_b * stride_dob + off_h * stride_doh +
325
+ offs_m[:, None] * stride_dom + offs_d[None, :],
326
+ mask=(offs_m[:, None] < seqlen_q) &
327
+ (offs_d[None, :] < headdim),
328
+ other=0.0).to(tl.float32)
329
+ delta = tl.sum(o * do, axis=1)
330
+ # write-back
331
+ tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
332
+
333
+
334
+ @triton.jit
335
+ def _bwd_kernel_one_col_block(
336
+ start_n,
337
+ Q,
338
+ K,
339
+ V,
340
+ Bias,
341
+ DO,
342
+ DQ,
343
+ DK,
344
+ DV,
345
+ LSE,
346
+ D,
347
+ softmax_scale,
348
+ stride_qm,
349
+ stride_kn,
350
+ stride_vn,
351
+ stride_bm,
352
+ stride_dom,
353
+ stride_dqm,
354
+ stride_dkn,
355
+ stride_dvn,
356
+ seqlen_q,
357
+ seqlen_k,
358
+ headdim,
359
+ ATOMIC_ADD: tl.constexpr,
360
+ BIAS_TYPE: tl.constexpr,
361
+ IS_CAUSAL: tl.constexpr,
362
+ BLOCK_HEADDIM: tl.constexpr,
363
+ EVEN_M: tl.constexpr,
364
+ EVEN_N: tl.constexpr,
365
+ EVEN_HEADDIM: tl.constexpr,
366
+ BLOCK_M: tl.constexpr,
367
+ BLOCK_N: tl.constexpr,
368
+ ):
369
+ # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
370
+ begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
371
+ # initialize row/col offsets
372
+ offs_qm = begin_m + tl.arange(0, BLOCK_M)
373
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
374
+ offs_m = tl.arange(0, BLOCK_M)
375
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
376
+ # initialize pointers to value-like data
377
+ q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
378
+ k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
379
+ v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
380
+ do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
381
+ dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
382
+ if BIAS_TYPE == 'vector':
383
+ b_ptrs = Bias + offs_n
384
+ elif BIAS_TYPE == 'matrix':
385
+ b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
386
+ else:
387
+ raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
388
+ # initialize dv and dk
389
+ dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
390
+ dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
391
+ # k and v stay in SRAM throughout
392
+ # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
393
+ # if we just call tl.load(k_ptrs), we get the wrong output!
394
+ if EVEN_N & EVEN_M:
395
+ if EVEN_HEADDIM:
396
+ k = tl.load(k_ptrs)
397
+ v = tl.load(v_ptrs)
398
+ else:
399
+ k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
400
+ v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
401
+ else:
402
+ if EVEN_HEADDIM:
403
+ k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
404
+ v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
405
+ else:
406
+ k = tl.load(k_ptrs,
407
+ mask=(offs_n[:, None] < seqlen_k) &
408
+ (offs_d[None, :] < headdim),
409
+ other=0.0)
410
+ v = tl.load(v_ptrs,
411
+ mask=(offs_n[:, None] < seqlen_k) &
412
+ (offs_d[None, :] < headdim),
413
+ other=0.0)
414
+ # loop over rows
415
+ num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
416
+ for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
417
+ start_m = tl.multiple_of(start_m, BLOCK_M)
418
+ offs_m_curr = start_m + offs_m
419
+ # load q, k, v, do on-chip
420
+ # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
421
+ if EVEN_M & EVEN_HEADDIM:
422
+ q = tl.load(q_ptrs)
423
+ else:
424
+ if EVEN_HEADDIM:
425
+ q = tl.load(q_ptrs,
426
+ mask=offs_m_curr[:, None] < seqlen_q,
427
+ other=0.0)
428
+ else:
429
+ q = tl.load(q_ptrs,
430
+ mask=(offs_m_curr[:, None] < seqlen_q) &
431
+ (offs_d[None, :] < headdim),
432
+ other=0.0)
433
+ # recompute p = softmax(qk, dim=-1).T
434
+ qk = tl.dot(q, k, trans_b=True)
435
+ # Trying to combine the two masks seem to make the result wrong
436
+ if not EVEN_N: # Need to mask out otherwise the softmax is wrong
437
+ qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
438
+ if IS_CAUSAL:
439
+ qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk,
440
+ float('-inf'))
441
+ if BIAS_TYPE != 'none':
442
+ if BIAS_TYPE == 'vector':
443
+ if EVEN_N:
444
+ bias = tl.load(b_ptrs).to(tl.float32)
445
+ else:
446
+ bias = tl.load(b_ptrs, mask=offs_n < seqlen_k,
447
+ other=0.0).to(tl.float32)
448
+ bias = bias[None, :]
449
+ elif BIAS_TYPE == 'matrix':
450
+ if EVEN_M & EVEN_N:
451
+ bias = tl.load(b_ptrs).to(tl.float32)
452
+ else:
453
+ bias = tl.load(b_ptrs,
454
+ mask=(offs_m_curr[:, None] < seqlen_q) &
455
+ (offs_n[None, :] < seqlen_k),
456
+ other=0.0).to(tl.float32)
457
+ else:
458
+ raise ValueError(
459
+ "BIAS_TYPE must be one of {'vector', 'matrix'}")
460
+ qk = qk * softmax_scale + bias
461
+ # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
462
+ # Also wrong for headdim=64.
463
+ if not (EVEN_M & EVEN_HEADDIM):
464
+ tl.debug_barrier()
465
+ lse_i = tl.load(LSE + offs_m_curr)
466
+ if BIAS_TYPE == 'none':
467
+ p = tl.exp(qk * softmax_scale - lse_i[:, None])
468
+ else:
469
+ p = tl.exp(qk - lse_i[:, None])
470
+ # compute dv
471
+ # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
472
+ # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
473
+ # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
474
+ # the output is correct.
475
+ if EVEN_M & EVEN_HEADDIM:
476
+ do = tl.load(do_ptrs)
477
+ else:
478
+ # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
479
+ do = tl.load(do_ptrs,
480
+ mask=(offs_m_curr[:, None] < seqlen_q) &
481
+ (offs_d[None, :] < headdim),
482
+ other=0.0)
483
+ # if EVEN_M:
484
+ # if EVEN_HEADDIM:
485
+ # do = tl.load(do_ptrs)
486
+ # else:
487
+ # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
488
+ # else:
489
+ # if EVEN_HEADDIM:
490
+ # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
491
+ # else:
492
+ # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
493
+ # & (offs_d[None, :] < headdim), other=0.0)
494
+ dv += tl.dot(p.to(do.dtype), do, trans_a=True)
495
+ # compute dp = dot(v, do)
496
+ # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
497
+ # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
498
+ # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
499
+ if not (EVEN_M & EVEN_HEADDIM):
500
+ tl.debug_barrier()
501
+ dp = tl.dot(do, v, trans_b=True)
502
+ # There's a race condition for headdim=48
503
+ if not EVEN_HEADDIM:
504
+ tl.debug_barrier()
505
+ # compute ds = p * (dp - delta[:, None])
506
+ # Putting the subtraction after the dp matmul (instead of before) is slightly faster
507
+ Di = tl.load(D + offs_m_curr)
508
+ # Converting ds to q.dtype here reduces register pressure and makes it much faster
509
+ # for BLOCK_HEADDIM=128
510
+ ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
511
+ # compute dk = dot(ds.T, q)
512
+ dk += tl.dot(ds, q, trans_a=True)
513
+ # compute dq
514
+ if not ATOMIC_ADD:
515
+ if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
516
+ dq = tl.load(dq_ptrs, eviction_policy='evict_last')
517
+ dq += tl.dot(ds, k)
518
+ tl.store(dq_ptrs, dq, eviction_policy='evict_last')
519
+ else:
520
+ if EVEN_HEADDIM:
521
+ dq = tl.load(dq_ptrs,
522
+ mask=offs_m_curr[:, None] < seqlen_q,
523
+ other=0.0,
524
+ eviction_policy='evict_last')
525
+ dq += tl.dot(ds, k)
526
+ tl.store(dq_ptrs,
527
+ dq,
528
+ mask=offs_m_curr[:, None] < seqlen_q,
529
+ eviction_policy='evict_last')
530
+ else:
531
+ dq = tl.load(dq_ptrs,
532
+ mask=(offs_m_curr[:, None] < seqlen_q) &
533
+ (offs_d[None, :] < headdim),
534
+ other=0.0,
535
+ eviction_policy='evict_last')
536
+ dq += tl.dot(ds, k)
537
+ tl.store(dq_ptrs,
538
+ dq,
539
+ mask=(offs_m_curr[:, None] < seqlen_q) &
540
+ (offs_d[None, :] < headdim),
541
+ eviction_policy='evict_last')
542
+ else: # If we're parallelizing across the seqlen_k dimension
543
+ dq = tl.dot(ds, k)
544
+ if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
545
+ tl.atomic_add(dq_ptrs, dq)
546
+ else:
547
+ if EVEN_HEADDIM:
548
+ tl.atomic_add(dq_ptrs,
549
+ dq,
550
+ mask=offs_m_curr[:, None] < seqlen_q)
551
+ else:
552
+ tl.atomic_add(dq_ptrs,
553
+ dq,
554
+ mask=(offs_m_curr[:, None] < seqlen_q) &
555
+ (offs_d[None, :] < headdim))
556
+ # increment pointers
557
+ dq_ptrs += BLOCK_M * stride_dqm
558
+ q_ptrs += BLOCK_M * stride_qm
559
+ do_ptrs += BLOCK_M * stride_dom
560
+ if BIAS_TYPE == 'matrix':
561
+ b_ptrs += BLOCK_M * stride_bm
562
+ # write-back
563
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
564
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
565
+ # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
566
+ # if we just call tl.store(dv_ptrs), there's a race condition
567
+ if EVEN_N & EVEN_M:
568
+ if EVEN_HEADDIM:
569
+ tl.store(dv_ptrs, dv)
570
+ tl.store(dk_ptrs, dk)
571
+ else:
572
+ tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
573
+ tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
574
+ else:
575
+ if EVEN_HEADDIM:
576
+ tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
577
+ tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
578
+ else:
579
+ tl.store(dv_ptrs,
580
+ dv,
581
+ mask=(offs_n[:, None] < seqlen_k) &
582
+ (offs_d[None, :] < headdim))
583
+ tl.store(dk_ptrs,
584
+ dk,
585
+ mask=(offs_n[:, None] < seqlen_k) &
586
+ (offs_d[None, :] < headdim))
587
+
588
+
589
+ def init_to_zero(name):
590
+ return lambda nargs: nargs[name].zero_()
591
+
592
+
593
+ @triton.autotune(
594
+ configs=[
595
+ triton.Config(
596
+ {
597
+ 'BLOCK_M': 128,
598
+ 'BLOCK_N': 128,
599
+ 'SEQUENCE_PARALLEL': False
600
+ },
601
+ num_warps=8,
602
+ num_stages=1,
603
+ pre_hook=init_to_zero('DQ')),
604
+ triton.Config(
605
+ {
606
+ 'BLOCK_M': 128,
607
+ 'BLOCK_N': 128,
608
+ 'SEQUENCE_PARALLEL': True
609
+ },
610
+ num_warps=8,
611
+ num_stages=1,
612
+ pre_hook=init_to_zero('DQ')),
613
+ # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
614
+ # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
615
+ # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
616
+ # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
617
+ # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
618
+ # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
619
+ ],
620
+ key=[
621
+ 'CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL',
622
+ 'BLOCK_HEADDIM'
623
+ ],
624
+ )
625
+ @triton.heuristics({
626
+ 'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0,
627
+ 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0,
628
+ 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM'],
629
+ })
630
+ @triton.jit
631
+ def _bwd_kernel(
632
+ Q,
633
+ K,
634
+ V,
635
+ Bias,
636
+ DO,
637
+ DQ,
638
+ DK,
639
+ DV,
640
+ LSE,
641
+ D,
642
+ softmax_scale,
643
+ stride_qb,
644
+ stride_qh,
645
+ stride_qm,
646
+ stride_kb,
647
+ stride_kh,
648
+ stride_kn,
649
+ stride_vb,
650
+ stride_vh,
651
+ stride_vn,
652
+ stride_bb,
653
+ stride_bh,
654
+ stride_bm,
655
+ stride_dob,
656
+ stride_doh,
657
+ stride_dom,
658
+ stride_dqb,
659
+ stride_dqh,
660
+ stride_dqm,
661
+ stride_dkb,
662
+ stride_dkh,
663
+ stride_dkn,
664
+ stride_dvb,
665
+ stride_dvh,
666
+ stride_dvn,
667
+ nheads,
668
+ seqlen_q,
669
+ seqlen_k,
670
+ seqlen_q_rounded,
671
+ headdim,
672
+ CACHE_KEY_SEQLEN_Q,
673
+ CACHE_KEY_SEQLEN_K,
674
+ BIAS_TYPE: tl.constexpr,
675
+ IS_CAUSAL: tl.constexpr,
676
+ BLOCK_HEADDIM: tl.constexpr,
677
+ SEQUENCE_PARALLEL: tl.constexpr,
678
+ EVEN_M: tl.constexpr,
679
+ EVEN_N: tl.constexpr,
680
+ EVEN_HEADDIM: tl.constexpr,
681
+ BLOCK_M: tl.constexpr,
682
+ BLOCK_N: tl.constexpr,
683
+ ):
684
+ off_hb = tl.program_id(1)
685
+ off_b = off_hb // nheads
686
+ off_h = off_hb % nheads
687
+ # offset pointers for batch/head
688
+ Q += off_b * stride_qb + off_h * stride_qh
689
+ K += off_b * stride_kb + off_h * stride_kh
690
+ V += off_b * stride_vb + off_h * stride_vh
691
+ DO += off_b * stride_dob + off_h * stride_doh
692
+ DQ += off_b * stride_dqb + off_h * stride_dqh
693
+ DK += off_b * stride_dkb + off_h * stride_dkh
694
+ DV += off_b * stride_dvb + off_h * stride_dvh
695
+ if BIAS_TYPE != 'none':
696
+ Bias += off_b * stride_bb + off_h * stride_bh
697
+ # pointer to row-wise quantities in value-like data
698
+ D += off_hb * seqlen_q_rounded
699
+ LSE += off_hb * seqlen_q_rounded
700
+ if not SEQUENCE_PARALLEL:
701
+ num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
702
+ for start_n in range(0, num_block_n):
703
+ _bwd_kernel_one_col_block(start_n,
704
+ Q,
705
+ K,
706
+ V,
707
+ Bias,
708
+ DO,
709
+ DQ,
710
+ DK,
711
+ DV,
712
+ LSE,
713
+ D,
714
+ softmax_scale,
715
+ stride_qm,
716
+ stride_kn,
717
+ stride_vn,
718
+ stride_bm,
719
+ stride_dom,
720
+ stride_dqm,
721
+ stride_dkn,
722
+ stride_dvn,
723
+ seqlen_q,
724
+ seqlen_k,
725
+ headdim,
726
+ ATOMIC_ADD=False,
727
+ BIAS_TYPE=BIAS_TYPE,
728
+ IS_CAUSAL=IS_CAUSAL,
729
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
730
+ EVEN_M=EVEN_M,
731
+ EVEN_N=EVEN_N,
732
+ EVEN_HEADDIM=EVEN_HEADDIM,
733
+ BLOCK_M=BLOCK_M,
734
+ BLOCK_N=BLOCK_N)
735
+ else:
736
+ start_n = tl.program_id(0)
737
+ _bwd_kernel_one_col_block(start_n,
738
+ Q,
739
+ K,
740
+ V,
741
+ Bias,
742
+ DO,
743
+ DQ,
744
+ DK,
745
+ DV,
746
+ LSE,
747
+ D,
748
+ softmax_scale,
749
+ stride_qm,
750
+ stride_kn,
751
+ stride_vn,
752
+ stride_bm,
753
+ stride_dom,
754
+ stride_dqm,
755
+ stride_dkn,
756
+ stride_dvn,
757
+ seqlen_q,
758
+ seqlen_k,
759
+ headdim,
760
+ ATOMIC_ADD=True,
761
+ BIAS_TYPE=BIAS_TYPE,
762
+ IS_CAUSAL=IS_CAUSAL,
763
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
764
+ EVEN_M=EVEN_M,
765
+ EVEN_N=EVEN_N,
766
+ EVEN_HEADDIM=EVEN_HEADDIM,
767
+ BLOCK_M=BLOCK_M,
768
+ BLOCK_N=BLOCK_N)
769
+
770
+
771
+ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
772
+ # shape constraints
773
+ batch, seqlen_q, nheads, d = q.shape
774
+ _, seqlen_k, _, _ = k.shape
775
+ assert k.shape == (batch, seqlen_k, nheads, d)
776
+ assert v.shape == (batch, seqlen_k, nheads, d)
777
+ assert d <= 128, 'FlashAttention only support head dimensions up to 128'
778
+ assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
779
+ assert q.dtype in [torch.float16,
780
+ torch.bfloat16], 'Only support fp16 and bf16'
781
+ assert q.is_cuda and k.is_cuda and v.is_cuda
782
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
783
+
784
+ has_bias = bias is not None
785
+ bias_type = 'none'
786
+ if has_bias:
787
+ assert bias.dtype in [q.dtype, torch.float]
788
+ assert bias.is_cuda
789
+ assert bias.dim() == 4
790
+ if bias.stride(-1) != 1:
791
+ bias = bias.contiguous()
792
+ if bias.shape[2:] == (1, seqlen_k):
793
+ bias_type = 'vector'
794
+ elif bias.shape[2:] == (seqlen_q, seqlen_k):
795
+ bias_type = 'matrix'
796
+ else:
797
+ raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
798
+ ' or (seqlen_q, seqlen_k)')
799
+ if bias.shape[:2] == (1, nheads):
800
+ bias = repeat(bias, '1 h ... -> b h ...', b=batch)
801
+ elif bias.shape[:2] == (batch, 1):
802
+ bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
803
+ elif bias.shape[:2] == (1, 1):
804
+ bias = repeat(bias, '1 h ... -> b h ...', b=batch)
805
+ bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
806
+ assert bias.shape[:2] == (
807
+ batch, nheads
808
+ ), f'First 2 dimensions of bias must be broadcastible to (batch, nheads) = ({batch, nheads}). Bias has shape: {bias.shape}'
809
+ assert bias is not None # for type checking
810
+ bias_strides = (bias.stride(0), bias.stride(1),
811
+ bias.stride(2)) if has_bias else (0, 0, 0)
812
+
813
+ seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
814
+ lse = torch.empty((batch, nheads, seqlen_q_rounded),
815
+ device=q.device,
816
+ dtype=torch.float32)
817
+ tmp = torch.empty((batch, nheads, seqlen_q_rounded),
818
+ device=q.device,
819
+ dtype=torch.float32)
820
+ o = torch.empty_like(q)
821
+
822
+ BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
823
+ # BLOCK = 128
824
+ # num_warps = 4 if d <= 64 else 8
825
+ grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
826
+ _fwd_kernel[grid]( # type: ignore
827
+ q,
828
+ k,
829
+ v,
830
+ bias,
831
+ o,
832
+ lse,
833
+ tmp,
834
+ softmax_scale,
835
+ q.stride(0),
836
+ q.stride(2),
837
+ q.stride(1),
838
+ k.stride(0),
839
+ k.stride(2),
840
+ k.stride(1),
841
+ v.stride(0),
842
+ v.stride(2),
843
+ v.stride(1),
844
+ *bias_strides,
845
+ o.stride(0),
846
+ o.stride(2),
847
+ o.stride(1),
848
+ nheads,
849
+ seqlen_q,
850
+ seqlen_k,
851
+ seqlen_q_rounded,
852
+ d,
853
+ seqlen_q // 32,
854
+ seqlen_k // 32, # key for triton cache (limit number of compilations)
855
+ # Can't use kwargs here because triton autotune expects key to be args, not kwargs
856
+ # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
857
+ bias_type,
858
+ causal,
859
+ BLOCK_HEADDIM,
860
+ # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
861
+ # num_warps=num_warps,
862
+ # num_stages=1,
863
+ )
864
+ return o, lse, softmax_scale # softmax_scale could have been updated
865
+
866
+
867
+ def _flash_attn_backward(do,
868
+ q,
869
+ k,
870
+ v,
871
+ o,
872
+ lse,
873
+ dq,
874
+ dk,
875
+ dv,
876
+ bias=None,
877
+ causal=False,
878
+ softmax_scale=None):
879
+ # Make sure that the last dimension is contiguous
880
+ if do.stride(-1) != 1:
881
+ do = do.contiguous()
882
+ batch, seqlen_q, nheads, d = q.shape
883
+ _, seqlen_k, _, _ = k.shape
884
+ # assert d in {16, 32, 64, 128}
885
+ assert d <= 128
886
+ seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
887
+ assert lse.shape == (batch, nheads, seqlen_q_rounded)
888
+ assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
889
+ assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
890
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
891
+ # dq_accum = torch.zeros_like(q, dtype=torch.float32)
892
+ dq_accum = torch.empty_like(q, dtype=torch.float32)
893
+ delta = torch.empty_like(lse)
894
+ # delta = torch.zeros_like(lse)
895
+
896
+ BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
897
+ grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
898
+ _bwd_preprocess_do_o_dot[grid]( # type: ignore
899
+ o,
900
+ do,
901
+ delta,
902
+ o.stride(0),
903
+ o.stride(2),
904
+ o.stride(1),
905
+ do.stride(0),
906
+ do.stride(2),
907
+ do.stride(1),
908
+ nheads,
909
+ seqlen_q,
910
+ seqlen_q_rounded,
911
+ d,
912
+ BLOCK_M=128,
913
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
914
+ )
915
+
916
+ has_bias = bias is not None
917
+ bias_type = 'none'
918
+ if has_bias:
919
+ assert bias.dtype in [q.dtype, torch.float]
920
+ assert bias.is_cuda
921
+ assert bias.dim() == 4
922
+ assert bias.stride(-1) == 1
923
+ if bias.shape[2:] == (1, seqlen_k):
924
+ bias_type = 'vector'
925
+ elif bias.shape[2:] == (seqlen_q, seqlen_k):
926
+ bias_type = 'matrix'
927
+ else:
928
+ raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
929
+ ' or (seqlen_q, seqlen_k)')
930
+ if bias.shape[:2] == (1, nheads):
931
+ bias = repeat(bias, '1 h ... -> b h ...', b=batch)
932
+ elif bias.shape[:2] == (batch, 1):
933
+ bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
934
+ elif bias.shape[:2] == (1, 1):
935
+ bias = repeat(bias, '1 h ... -> b h ...', b=batch)
936
+ bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
937
+ assert bias.shape[:2] == (
938
+ batch, nheads
939
+ ), f'First 2 dimensions of bias must be broadcastible to (batch, nheads) = ({batch, nheads}). Bias has shape: {bias.shape}'
940
+ assert bias is not None # type checking
941
+ bias_strides = (bias.stride(0), bias.stride(1),
942
+ bias.stride(2)) if has_bias else (0, 0, 0)
943
+
944
+ # BLOCK_M = 128
945
+ # BLOCK_N = 64
946
+ # num_warps = 4
947
+ grid = lambda META: (triton.cdiv(seqlen_k, META['BLOCK_N'])
948
+ if META['SEQUENCE_PARALLEL'] else 1, batch * nheads)
949
+ _bwd_kernel[grid]( # type: ignore
950
+ q,
951
+ k,
952
+ v,
953
+ bias,
954
+ do,
955
+ dq_accum,
956
+ dk,
957
+ dv,
958
+ lse,
959
+ delta,
960
+ softmax_scale,
961
+ q.stride(0),
962
+ q.stride(2),
963
+ q.stride(1),
964
+ k.stride(0),
965
+ k.stride(2),
966
+ k.stride(1),
967
+ v.stride(0),
968
+ v.stride(2),
969
+ v.stride(1),
970
+ *bias_strides,
971
+ do.stride(0),
972
+ do.stride(2),
973
+ do.stride(1),
974
+ dq_accum.stride(0),
975
+ dq_accum.stride(2),
976
+ dq_accum.stride(1),
977
+ dk.stride(0),
978
+ dk.stride(2),
979
+ dk.stride(1),
980
+ dv.stride(0),
981
+ dv.stride(2),
982
+ dv.stride(1),
983
+ nheads,
984
+ seqlen_q,
985
+ seqlen_k,
986
+ seqlen_q_rounded,
987
+ d,
988
+ seqlen_q // 32,
989
+ seqlen_k // 32, # key for triton cache (limit number of compilations)
990
+ # Can't use kwargs here because triton autotune expects key to be args, not kwargs
991
+ # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
992
+ bias_type,
993
+ causal,
994
+ BLOCK_HEADDIM,
995
+ # SEQUENCE_PARALLEL=False,
996
+ # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
997
+ # num_warps=num_warps,
998
+ # num_stages=1,
999
+ )
1000
+ dq.copy_(dq_accum)
1001
+
1002
+
1003
+ class _FlashAttnQKVPackedFunc(torch.autograd.Function):
1004
+
1005
+ @staticmethod
1006
+ def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
1007
+ """Forward pass for packed FlashAttention.
1008
+
1009
+ Args:
1010
+ ctx: autograd context
1011
+ qkv: (batch, seqlen, 3, nheads, headdim)
1012
+ bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
1013
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
1014
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
1015
+ causal (bool): whether to incorporate causal attention masking
1016
+ softmax_scale (float, optional): scale factor for softmax
1017
+ """
1018
+ # Make sure that the last dimension is contiguous
1019
+ if qkv.stride(-1) != 1:
1020
+ qkv = qkv.contiguous()
1021
+ o, lse, ctx.softmax_scale = _flash_attn_forward(
1022
+ qkv[:, :, 0],
1023
+ qkv[:, :, 1],
1024
+ qkv[:, :, 2],
1025
+ bias=bias,
1026
+ causal=causal,
1027
+ softmax_scale=softmax_scale)
1028
+ ctx.save_for_backward(qkv, o, lse, bias)
1029
+ ctx.causal = causal
1030
+ return o
1031
+
1032
+ @staticmethod
1033
+ def backward(ctx, do):
1034
+ qkv, o, lse, bias = ctx.saved_tensors
1035
+ assert not ctx.needs_input_grad[
1036
+ 1], 'FlashAttention does not support bias gradient yet'
1037
+ # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
1038
+ # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
1039
+ with torch.inference_mode():
1040
+ dqkv = torch.empty_like(qkv)
1041
+ _flash_attn_backward(do,
1042
+ qkv[:, :, 0],
1043
+ qkv[:, :, 1],
1044
+ qkv[:, :, 2],
1045
+ o,
1046
+ lse,
1047
+ dqkv[:, :, 0],
1048
+ dqkv[:, :, 1],
1049
+ dqkv[:, :, 2],
1050
+ bias=bias,
1051
+ causal=ctx.causal,
1052
+ softmax_scale=ctx.softmax_scale)
1053
+ return dqkv, None, None, None
1054
+
1055
+
1056
+ flash_attn_qkvpacked_func = _FlashAttnQKVPackedFunc.apply
1057
+
1058
+
1059
+ class _FlashAttnFunc(torch.autograd.Function):
1060
+
1061
+ @staticmethod
1062
+ def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
1063
+ """Forward pass for FlashAttention.
1064
+
1065
+ Args:
1066
+ ctx: autograd context
1067
+ q: (batch_size, seqlen_q, nheads, headdim)
1068
+ k: (batch_size, seqlen_k, nheads, headdim)
1069
+ v: (batch_size, seqlen_k, nheads, headdim)
1070
+ bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
1071
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
1072
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
1073
+ causal (bool): whether to incorporate causal attention masking
1074
+ softmax_scale (float, optional): scale factor for softmax
1075
+ """
1076
+ # Make sure that the last dimension is contiguous
1077
+ q, k, v = [
1078
+ x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]
1079
+ ]
1080
+ o, lse, ctx.softmax_scale = _flash_attn_forward(
1081
+ q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)
1082
+ ctx.save_for_backward(q, k, v, o, lse, bias)
1083
+ ctx.causal = causal
1084
+ return o
1085
+
1086
+ @staticmethod
1087
+ def backward(ctx, do):
1088
+ q, k, v, o, lse, bias = ctx.saved_tensors
1089
+ assert not ctx.needs_input_grad[
1090
+ 3], 'FlashAttention does not support bias gradient yet'
1091
+ # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
1092
+ # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
1093
+ with torch.inference_mode():
1094
+ dq = torch.empty_like(q)
1095
+ dk = torch.empty_like(k)
1096
+ dv = torch.empty_like(v)
1097
+ _flash_attn_backward(do,
1098
+ q,
1099
+ k,
1100
+ v,
1101
+ o,
1102
+ lse,
1103
+ dq,
1104
+ dk,
1105
+ dv,
1106
+ bias=bias,
1107
+ causal=ctx.causal,
1108
+ softmax_scale=ctx.softmax_scale)
1109
+ return dq, dk, dv, None, None, None
1110
+
1111
+
1112
+ flash_attn_func = _FlashAttnFunc.apply
modules.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ },
14
+ {
15
+ "idx": 2,
16
+ "name": "2",
17
+ "path": "2_Normalize",
18
+ "type": "sentence_transformers.models.Normalize"
19
+ }
20
+ ]
mteb_results/AmazonCounterfactualClassification.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "e8379541af4e31359cca9fbcf4b00f2671dba205",
3
+ "mteb_dataset_name": "AmazonCounterfactualClassification",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "en": {
7
+ "accuracy": 0.697313432835821,
8
+ "accuracy_stderr": 0.04363113167902916,
9
+ "ap": 0.31618259511417734,
10
+ "ap_stderr": 0.0243939127481388,
11
+ "f1": 0.6330313825394228,
12
+ "f1_stderr": 0.03331211721747352,
13
+ "main_score": 0.697313432835821
14
+ },
15
+ "evaluation_time": 3.55
16
+ },
17
+ "validation": {
18
+ "en": {
19
+ "accuracy": 0.7074626865671642,
20
+ "accuracy_stderr": 0.03173177854547658,
21
+ "ap": 0.2916547890175021,
22
+ "ap_stderr": 0.028577509879931906,
23
+ "f1": 0.628207439570022,
24
+ "f1_stderr": 0.02728677964172927,
25
+ "main_score": 0.7074626865671642
26
+ },
27
+ "evaluation_time": 7.2
28
+ }
29
+ }
mteb_results/AmazonPolarityClassification.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "e2d317d38cd51312af73b3d32a06d1a08b442046",
3
+ "mteb_dataset_name": "AmazonPolarityClassification",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "accuracy": 0.8689837499999999,
7
+ "accuracy_stderr": 0.010742354621427285,
8
+ "ap": 0.8239500885672127,
9
+ "ap_stderr": 0.013236818266475252,
10
+ "evaluation_time": 1082.95,
11
+ "f1": 0.8687317947399658,
12
+ "f1_stderr": 0.011035411217540664,
13
+ "main_score": 0.8689837499999999
14
+ }
15
+ }
mteb_results/AmazonReviewsClassification.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "1399c76144fd37290681b995c656ef9b2e06e26d",
3
+ "mteb_dataset_name": "AmazonReviewsClassification",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "en": {
7
+ "accuracy": 0.44049999999999995,
8
+ "accuracy_stderr": 0.014423938435808711,
9
+ "f1": 0.4267624383248947,
10
+ "f1_stderr": 0.01351683620968048,
11
+ "main_score": 0.44049999999999995
12
+ },
13
+ "evaluation_time": 11.38
14
+ },
15
+ "validation": {
16
+ "en": {
17
+ "accuracy": 0.43798000000000004,
18
+ "accuracy_stderr": 0.012288352208494032,
19
+ "f1": 0.42483998553432956,
20
+ "f1_stderr": 0.015752944478543963,
21
+ "main_score": 0.43798000000000004
22
+ },
23
+ "evaluation_time": 13.69
24
+ }
25
+ }
mteb_results/ArguAna.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": null,
3
+ "mteb_dataset_name": "ArguAna",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 52.53,
7
+ "map_at_1": 0.26174,
8
+ "map_at_10": 0.40976,
9
+ "map_at_100": 0.42067,
10
+ "map_at_1000": 0.42075,
11
+ "map_at_3": 0.35917,
12
+ "map_at_5": 0.38656,
13
+ "mrr_at_1": 0.26814,
14
+ "mrr_at_10": 0.41252,
15
+ "mrr_at_100": 0.42337,
16
+ "mrr_at_1000": 0.42345,
17
+ "mrr_at_3": 0.36226,
18
+ "mrr_at_5": 0.38914,
19
+ "ndcg_at_1": 0.26174,
20
+ "ndcg_at_10": 0.49819,
21
+ "ndcg_at_100": 0.54404,
22
+ "ndcg_at_1000": 0.5459,
23
+ "ndcg_at_3": 0.39231,
24
+ "ndcg_at_5": 0.44189,
25
+ "precision_at_1": 0.26174,
26
+ "precision_at_10": 0.07838,
27
+ "precision_at_100": 0.00982,
28
+ "precision_at_1000": 0.001,
29
+ "precision_at_3": 0.16287,
30
+ "precision_at_5": 0.12191,
31
+ "recall_at_1": 0.26174,
32
+ "recall_at_10": 0.78378,
33
+ "recall_at_100": 0.98222,
34
+ "recall_at_1000": 0.99644,
35
+ "recall_at_3": 0.48862,
36
+ "recall_at_5": 0.60953
37
+ }
38
+ }
mteb_results/ArxivClusteringP2P.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "a122ad7f3f0291bf49cc6f4d32aa80929df69d5d",
3
+ "mteb_dataset_name": "ArxivClusteringP2P",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 4034.06,
7
+ "v_measure": 0.4231689035788179,
8
+ "v_measure_std": 0.1399577095144373
9
+ }
10
+ }
mteb_results/ArxivClusteringS2S.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "f910caf1a6075f7329cdf8c1a6135696f37dbd53",
3
+ "mteb_dataset_name": "ArxivClusteringS2S",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 432.48,
7
+ "v_measure": 0.31280245136660983,
8
+ "v_measure_std": 0.14616358182910433
9
+ }
10
+ }
mteb_results/AskUbuntuDupQuestions.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "2000358ca161889fa9c082cb41daa8dcfb161a54",
3
+ "mteb_dataset_name": "AskUbuntuDupQuestions",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 4.26,
7
+ "map": 0.5879109720839415,
8
+ "mrr": 0.7179615705931495
9
+ }
10
+ }
mteb_results/BIOSSES.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "d3fb88f8f02e40887cd149695127462bbcf29b4a",
3
+ "mteb_dataset_name": "BIOSSES",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "cos_sim": {
7
+ "pearson": 0.7644918756608116,
8
+ "spearman": 0.7086607256286257
9
+ },
10
+ "euclidean": {
11
+ "pearson": 0.7412154678100815,
12
+ "spearman": 0.7086607256286257
13
+ },
14
+ "evaluation_time": 1.08,
15
+ "manhattan": {
16
+ "pearson": 0.7400786269644171,
17
+ "spearman": 0.7068353828321327
18
+ }
19
+ }
20
+ }
mteb_results/Banking77Classification.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "0fd18e25b25c072e09e0d92ab615fda904d66300",
3
+ "mteb_dataset_name": "Banking77Classification",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "accuracy": 0.7540584415584415,
7
+ "accuracy_stderr": 0.007828985179390284,
8
+ "evaluation_time": 20.37,
9
+ "f1": 0.7429514617572676,
10
+ "f1_stderr": 0.00868929710762345,
11
+ "main_score": 0.7540584415584415
12
+ }
13
+ }
mteb_results/BiorxivClusteringP2P.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "65b79d1d13f80053f67aca9498d9402c2d9f1f40",
3
+ "mteb_dataset_name": "BiorxivClusteringP2P",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 547.26,
7
+ "v_measure": 0.3741860080664014,
8
+ "v_measure_std": 0.008407780040443218
9
+ }
10
+ }
mteb_results/BiorxivClusteringS2S.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "258694dd0231531bc1fd9de6ceb52a0853c6d908",
3
+ "mteb_dataset_name": "BiorxivClusteringS2S",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 53.57,
7
+ "v_measure": 0.29319217023090705,
8
+ "v_measure_std": 0.010219281239166302
9
+ }
10
+ }
mteb_results/CQADupstackEnglishRetrieval.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": null,
3
+ "mteb_dataset_name": "CQADupstackEnglishRetrieval",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 113.5,
7
+ "map_at_1": 0.22528,
8
+ "map_at_10": 0.30751,
9
+ "map_at_100": 0.31855,
10
+ "map_at_1000": 0.31972,
11
+ "map_at_3": 0.28465,
12
+ "map_at_5": 0.29738,
13
+ "mrr_at_1": 0.28662,
14
+ "mrr_at_10": 0.35912,
15
+ "mrr_at_100": 0.36726,
16
+ "mrr_at_1000": 0.36777,
17
+ "mrr_at_3": 0.34013,
18
+ "mrr_at_5": 0.35156,
19
+ "ndcg_at_1": 0.28662,
20
+ "ndcg_at_10": 0.35452,
21
+ "ndcg_at_100": 0.401,
22
+ "ndcg_at_1000": 0.42323,
23
+ "ndcg_at_3": 0.32112,
24
+ "ndcg_at_5": 0.33638,
25
+ "precision_at_1": 0.28662,
26
+ "precision_at_10": 0.06688,
27
+ "precision_at_100": 0.0113,
28
+ "precision_at_1000": 0.0016,
29
+ "precision_at_3": 0.15563,
30
+ "precision_at_5": 0.11019,
31
+ "recall_at_1": 0.22528,
32
+ "recall_at_10": 0.43748,
33
+ "recall_at_100": 0.64235,
34
+ "recall_at_1000": 0.78609,
35
+ "recall_at_3": 0.33937,
36
+ "recall_at_5": 0.38234
37
+ }
38
+ }
mteb_results/ClimateFEVER.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": null,
3
+ "mteb_dataset_name": "ClimateFEVER",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 8671.85,
7
+ "map_at_1": 0.09468,
8
+ "map_at_10": 0.16029,
9
+ "map_at_100": 0.17693,
10
+ "map_at_1000": 0.17886,
11
+ "map_at_3": 0.1315,
12
+ "map_at_5": 0.14568,
13
+ "mrr_at_1": 0.21173,
14
+ "mrr_at_10": 0.31028,
15
+ "mrr_at_100": 0.32061,
16
+ "mrr_at_1000": 0.32119,
17
+ "mrr_at_3": 0.27535,
18
+ "mrr_at_5": 0.29431,
19
+ "ndcg_at_1": 0.21173,
20
+ "ndcg_at_10": 0.23224,
21
+ "ndcg_at_100": 0.30225,
22
+ "ndcg_at_1000": 0.33961,
23
+ "ndcg_at_3": 0.18174,
24
+ "ndcg_at_5": 0.19897,
25
+ "precision_at_1": 0.21173,
26
+ "precision_at_10": 0.07472,
27
+ "precision_at_100": 0.01501,
28
+ "precision_at_1000": 0.00219,
29
+ "precision_at_3": 0.13312,
30
+ "precision_at_5": 0.10619,
31
+ "recall_at_1": 0.09468,
32
+ "recall_at_10": 0.28823,
33
+ "recall_at_100": 0.53265,
34
+ "recall_at_1000": 0.74536,
35
+ "recall_at_3": 0.16672,
36
+ "recall_at_5": 0.21302
37
+ }
38
+ }
mteb_results/DBPedia.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": null,
3
+ "mteb_dataset_name": "DBPedia",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 4445.99,
7
+ "map_at_1": 0.06343,
8
+ "map_at_10": 0.12717,
9
+ "map_at_100": 0.1648,
10
+ "map_at_1000": 0.17381,
11
+ "map_at_3": 0.09569,
12
+ "map_at_5": 0.11125,
13
+ "mrr_at_1": 0.4875,
14
+ "mrr_at_10": 0.58425,
15
+ "mrr_at_100": 0.59075,
16
+ "mrr_at_1000": 0.59095,
17
+ "mrr_at_3": 0.56292,
18
+ "mrr_at_5": 0.57679,
19
+ "ndcg_at_1": 0.37875,
20
+ "ndcg_at_10": 0.2777,
21
+ "ndcg_at_100": 0.30289,
22
+ "ndcg_at_1000": 0.36188,
23
+ "ndcg_at_3": 0.31386,
24
+ "ndcg_at_5": 0.29923,
25
+ "precision_at_1": 0.4875,
26
+ "precision_at_10": 0.22375,
27
+ "precision_at_100": 0.06342,
28
+ "precision_at_1000": 0.01449,
29
+ "precision_at_3": 0.355,
30
+ "precision_at_5": 0.3055,
31
+ "recall_at_1": 0.06343,
32
+ "recall_at_10": 0.16936,
33
+ "recall_at_100": 0.35956,
34
+ "recall_at_1000": 0.55787,
35
+ "recall_at_3": 0.10771,
36
+ "recall_at_5": 0.1367
37
+ }
38
+ }
mteb_results/EmotionClassification.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "4f58c6b202a23cf9a4da393831edf4f9183cad37",
3
+ "mteb_dataset_name": "EmotionClassification",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "accuracy": 0.4199,
7
+ "accuracy_stderr": 0.02234367919569201,
8
+ "evaluation_time": 3.37,
9
+ "f1": 0.3682340217456495,
10
+ "f1_stderr": 0.021776128234136445,
11
+ "main_score": 0.4199
12
+ },
13
+ "validation": {
14
+ "accuracy": 0.41864999999999997,
15
+ "accuracy_stderr": 0.022959801828413062,
16
+ "evaluation_time": 3.29,
17
+ "f1": 0.3748604511300154,
18
+ "f1_stderr": 0.02042335727335004,
19
+ "main_score": 0.41864999999999997
20
+ }
21
+ }
mteb_results/FEVER.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": null,
3
+ "mteb_dataset_name": "FEVER",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 16698.45,
7
+ "map_at_1": 0.40088,
8
+ "map_at_10": 0.52692,
9
+ "map_at_100": 0.53296,
10
+ "map_at_1000": 0.53325,
11
+ "map_at_3": 0.49905,
12
+ "map_at_5": 0.51617,
13
+ "mrr_at_1": 0.43009,
14
+ "mrr_at_10": 0.56203,
15
+ "mrr_at_100": 0.5675,
16
+ "mrr_at_1000": 0.56769,
17
+ "mrr_at_3": 0.534,
18
+ "mrr_at_5": 0.55163,
19
+ "ndcg_at_1": 0.43009,
20
+ "ndcg_at_10": 0.5939,
21
+ "ndcg_at_100": 0.6213,
22
+ "ndcg_at_1000": 0.62793,
23
+ "ndcg_at_3": 0.53878,
24
+ "ndcg_at_5": 0.56887,
25
+ "precision_at_1": 0.43009,
26
+ "precision_at_10": 0.08366,
27
+ "precision_at_100": 0.00983,
28
+ "precision_at_1000": 0.00105,
29
+ "precision_at_3": 0.22377,
30
+ "precision_at_5": 0.15035,
31
+ "recall_at_1": 0.40088,
32
+ "recall_at_10": 0.76687,
33
+ "recall_at_100": 0.8891,
34
+ "recall_at_1000": 0.93782,
35
+ "recall_at_3": 0.6181,
36
+ "recall_at_5": 0.69131
37
+ }
38
+ }
mteb_results/FiQA2018.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": null,
3
+ "mteb_dataset_name": "FiQA2018",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 138.39,
7
+ "map_at_1": 0.10817,
8
+ "map_at_10": 0.189,
9
+ "map_at_100": 0.20448,
10
+ "map_at_1000": 0.20661,
11
+ "map_at_3": 0.15979,
12
+ "map_at_5": 0.17415,
13
+ "mrr_at_1": 0.23148,
14
+ "mrr_at_10": 0.31208,
15
+ "mrr_at_100": 0.32167,
16
+ "mrr_at_1000": 0.32242,
17
+ "mrr_at_3": 0.28498,
18
+ "mrr_at_5": 0.29964,
19
+ "ndcg_at_1": 0.23148,
20
+ "ndcg_at_10": 0.25326,
21
+ "ndcg_at_100": 0.31927,
22
+ "ndcg_at_1000": 0.36081,
23
+ "ndcg_at_3": 0.21647,
24
+ "ndcg_at_5": 0.22763,
25
+ "precision_at_1": 0.23148,
26
+ "precision_at_10": 0.07546,
27
+ "precision_at_100": 0.01415,
28
+ "precision_at_1000": 0.00216,
29
+ "precision_at_3": 0.14969,
30
+ "precision_at_5": 0.11327,
31
+ "recall_at_1": 0.10817,
32
+ "recall_at_10": 0.32164,
33
+ "recall_at_100": 0.57655,
34
+ "recall_at_1000": 0.82797,
35
+ "recall_at_3": 0.19709,
36
+ "recall_at_5": 0.24333
37
+ }
38
+ }
mteb_results/HotpotQA.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": null,
3
+ "mteb_dataset_name": "HotpotQA",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 5192.2,
7
+ "map_at_1": 0.25381,
8
+ "map_at_10": 0.3314,
9
+ "map_at_100": 0.33948,
10
+ "map_at_1000": 0.34028,
11
+ "map_at_3": 0.3102,
12
+ "map_at_5": 0.3223,
13
+ "mrr_at_1": 0.50763,
14
+ "mrr_at_10": 0.57899,
15
+ "mrr_at_100": 0.58426,
16
+ "mrr_at_1000": 0.58457,
17
+ "mrr_at_3": 0.56093,
18
+ "mrr_at_5": 0.57116,
19
+ "ndcg_at_1": 0.50763,
20
+ "ndcg_at_10": 0.41656,
21
+ "ndcg_at_100": 0.45079,
22
+ "ndcg_at_1000": 0.46917,
23
+ "ndcg_at_3": 0.37834,
24
+ "ndcg_at_5": 0.39732,
25
+ "precision_at_1": 0.50763,
26
+ "precision_at_10": 0.08648,
27
+ "precision_at_100": 0.01135,
28
+ "precision_at_1000": 0.00138,
29
+ "precision_at_3": 0.23106,
30
+ "precision_at_5": 0.15363,
31
+ "recall_at_1": 0.25381,
32
+ "recall_at_10": 0.43241,
33
+ "recall_at_100": 0.56745,
34
+ "recall_at_1000": 0.69048,
35
+ "recall_at_3": 0.34659,
36
+ "recall_at_5": 0.38406
37
+ }
38
+ }
mteb_results/ImdbClassification.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "3d86128a09e091d6018b6d26cad27f2739fc2db7",
3
+ "mteb_dataset_name": "ImdbClassification",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "accuracy": 0.79544,
7
+ "accuracy_stderr": 0.022193916283522398,
8
+ "ap": 0.7382920133396664,
9
+ "ap_stderr": 0.029776228173533717,
10
+ "evaluation_time": 205.12,
11
+ "f1": 0.7951048124883265,
12
+ "f1_stderr": 0.02219958939576688,
13
+ "main_score": 0.79544
14
+ }
15
+ }
mteb_results/MSMARCO.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": null,
3
+ "dev": {
4
+ "evaluation_time": 19626.59,
5
+ "map_at_1": 0.11174,
6
+ "map_at_10": 0.19452,
7
+ "map_at_100": 0.20612,
8
+ "map_at_1000": 0.20703,
9
+ "map_at_3": 0.16444,
10
+ "map_at_5": 0.18083,
11
+ "mrr_at_1": 0.11447,
12
+ "mrr_at_10": 0.19808,
13
+ "mrr_at_100": 0.20958,
14
+ "mrr_at_1000": 0.21042,
15
+ "mrr_at_3": 0.16791,
16
+ "mrr_at_5": 0.18459,
17
+ "ndcg_at_1": 0.11447,
18
+ "ndcg_at_10": 0.24556,
19
+ "ndcg_at_100": 0.30638,
20
+ "ndcg_at_1000": 0.3314,
21
+ "ndcg_at_3": 0.18325,
22
+ "ndcg_at_5": 0.21278,
23
+ "precision_at_1": 0.11447,
24
+ "precision_at_10": 0.04215,
25
+ "precision_at_100": 0.00732,
26
+ "precision_at_1000": 0.00095,
27
+ "precision_at_3": 0.08052,
28
+ "precision_at_5": 0.06318,
29
+ "recall_at_1": 0.11174,
30
+ "recall_at_10": 0.40543,
31
+ "recall_at_100": 0.69699,
32
+ "recall_at_1000": 0.89403,
33
+ "recall_at_3": 0.23442,
34
+ "recall_at_5": 0.30536
35
+ },
36
+ "mteb_dataset_name": "MSMARCO",
37
+ "mteb_version": "1.1.0"
38
+ }
mteb_results/MTOPDomainClassification.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "d80d48c1eb48d3562165c59d59d0034df9fff0bf",
3
+ "mteb_dataset_name": "MTOPDomainClassification",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "en": {
7
+ "accuracy": 0.8966712266301871,
8
+ "accuracy_stderr": 0.009523011920085962,
9
+ "f1": 0.8957660424361247,
10
+ "f1_stderr": 0.009247170021662966,
11
+ "main_score": 0.8966712266301871
12
+ },
13
+ "evaluation_time": 7.75
14
+ },
15
+ "validation": {
16
+ "en": {
17
+ "accuracy": 0.9017002237136464,
18
+ "accuracy_stderr": 0.009890167527403295,
19
+ "f1": 0.9039792204701363,
20
+ "f1_stderr": 0.009182351003334687,
21
+ "main_score": 0.9017002237136464
22
+ },
23
+ "evaluation_time": 4.88
24
+ }
25
+ }
mteb_results/MTOPIntentClassification.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "ae001d0e6b1228650b7bd1c2c65fb50ad11a8aba",
3
+ "mteb_dataset_name": "MTOPIntentClassification",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "en": {
7
+ "accuracy": 0.6028499772001825,
8
+ "accuracy_stderr": 0.018495543127366038,
9
+ "f1": 0.40306374001528233,
10
+ "f1_stderr": 0.011859407815520086,
11
+ "main_score": 0.6028499772001825
12
+ },
13
+ "evaluation_time": 30.96
14
+ },
15
+ "validation": {
16
+ "en": {
17
+ "accuracy": 0.6150335570469799,
18
+ "accuracy_stderr": 0.01903139236025276,
19
+ "f1": 0.4147129810603558,
20
+ "f1_stderr": 0.015560901035463594,
21
+ "main_score": 0.6150335570469799
22
+ },
23
+ "evaluation_time": 28.07
24
+ }
25
+ }
mteb_results/MassiveIntentClassification.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "31efe3c427b0bae9c22cbb560b8f15491cc6bed7",
3
+ "mteb_dataset_name": "MassiveIntentClassification",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "en": {
7
+ "accuracy": 0.6333557498318763,
8
+ "accuracy_stderr": 0.014612806300514952,
9
+ "f1": 0.6024039910680179,
10
+ "f1_stderr": 0.012256367770368185,
11
+ "main_score": 0.6333557498318763
12
+ },
13
+ "evaluation_time": 22.52
14
+ },
15
+ "validation": {
16
+ "en": {
17
+ "accuracy": 0.6426955238563699,
18
+ "accuracy_stderr": 0.01633350887848132,
19
+ "f1": 0.5828069832892886,
20
+ "f1_stderr": 0.013604921852646317,
21
+ "main_score": 0.6426955238563699
22
+ },
23
+ "evaluation_time": 17.6
24
+ }
25
+ }
mteb_results/MassiveScenarioClassification.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "7d571f92784cd94a019292a1f45445077d0ef634",
3
+ "mteb_dataset_name": "MassiveScenarioClassification",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "en": {
7
+ "accuracy": 0.7237390719569603,
8
+ "accuracy_stderr": 0.006043481355389665,
9
+ "f1": 0.7233097333477316,
10
+ "f1_stderr": 0.0075559844507943974,
11
+ "main_score": 0.7237390719569603
12
+ },
13
+ "evaluation_time": 6.86
14
+ },
15
+ "validation": {
16
+ "en": {
17
+ "accuracy": 0.7321200196753566,
18
+ "accuracy_stderr": 0.010745609148770754,
19
+ "f1": 0.7288011677053199,
20
+ "f1_stderr": 0.010826173990376636,
21
+ "main_score": 0.7321200196753566
22
+ },
23
+ "evaluation_time": 5.61
24
+ }
25
+ }
mteb_results/MedrxivClusteringP2P.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "e7a26af6f3ae46b30dde8737f02c07b1505bcc73",
3
+ "mteb_dataset_name": "MedrxivClusteringP2P",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 226.37,
7
+ "v_measure": 0.34681589390605516,
8
+ "v_measure_std": 0.01515645822647098
9
+ }
10
+ }
mteb_results/MedrxivClusteringS2S.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "35191c8c0dca72d8ff3efcd72aa802307d469663",
3
+ "mteb_dataset_name": "MedrxivClusteringS2S",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 31.13,
7
+ "v_measure": 0.30340061711905236,
8
+ "v_measure_std": 0.012579424998938571
9
+ }
10
+ }
mteb_results/MindSmallReranking.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "3bdac13927fdc888b903db93b2ffdbd90b295a69",
3
+ "mteb_dataset_name": "MindSmallReranking",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 1849.79,
7
+ "map": 0.32018143262958026,
8
+ "mrr": 0.33205552400553673
9
+ }
10
+ }
mteb_results/NFCorpus.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": null,
3
+ "mteb_dataset_name": "NFCorpus",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 17.47,
7
+ "map_at_1": 0.03391,
8
+ "map_at_10": 0.07722,
9
+ "map_at_100": 0.10286,
10
+ "map_at_1000": 0.11668,
11
+ "map_at_3": 0.05552,
12
+ "map_at_5": 0.06468,
13
+ "mrr_at_1": 0.34365,
14
+ "mrr_at_10": 0.42555,
15
+ "mrr_at_100": 0.43295,
16
+ "mrr_at_1000": 0.43357,
17
+ "mrr_at_3": 0.40299,
18
+ "mrr_at_5": 0.41182,
19
+ "ndcg_at_1": 0.31424,
20
+ "ndcg_at_10": 0.24758,
21
+ "ndcg_at_100": 0.23678,
22
+ "ndcg_at_1000": 0.33377,
23
+ "ndcg_at_3": 0.28302,
24
+ "ndcg_at_5": 0.26342,
25
+ "precision_at_1": 0.33437,
26
+ "precision_at_10": 0.19257,
27
+ "precision_at_100": 0.06663,
28
+ "precision_at_1000": 0.0199,
29
+ "precision_at_3": 0.27761,
30
+ "precision_at_5": 0.23715,
31
+ "recall_at_1": 0.03391,
32
+ "recall_at_10": 0.11068,
33
+ "recall_at_100": 0.25878,
34
+ "recall_at_1000": 0.6019,
35
+ "recall_at_3": 0.06169,
36
+ "recall_at_5": 0.07767
37
+ }
38
+ }
mteb_results/NQ.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": null,
3
+ "mteb_dataset_name": "NQ",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 7686.43,
7
+ "map_at_1": 0.15168,
8
+ "map_at_10": 0.26177,
9
+ "map_at_100": 0.27564,
10
+ "map_at_1000": 0.27629,
11
+ "map_at_3": 0.2203,
12
+ "map_at_5": 0.24276,
13
+ "mrr_at_1": 0.17439,
14
+ "mrr_at_10": 0.28205,
15
+ "mrr_at_100": 0.29357,
16
+ "mrr_at_1000": 0.29408,
17
+ "mrr_at_3": 0.24377,
18
+ "mrr_at_5": 0.2654,
19
+ "ndcg_at_1": 0.1741,
20
+ "ndcg_at_10": 0.32936,
21
+ "ndcg_at_100": 0.39197,
22
+ "ndcg_at_1000": 0.40892,
23
+ "ndcg_at_3": 0.24721,
24
+ "ndcg_at_5": 0.28615,
25
+ "precision_at_1": 0.1741,
26
+ "precision_at_10": 0.06199,
27
+ "precision_at_100": 0.00969,
28
+ "precision_at_1000": 0.00113,
29
+ "precision_at_3": 0.1179,
30
+ "precision_at_5": 0.09264,
31
+ "recall_at_1": 0.15168,
32
+ "recall_at_10": 0.51914,
33
+ "recall_at_100": 0.79804,
34
+ "recall_at_1000": 0.9276,
35
+ "recall_at_3": 0.30212,
36
+ "recall_at_5": 0.39204
37
+ }
38
+ }
mteb_results/QuoraRetrieval.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": null,
3
+ "mteb_dataset_name": "QuoraRetrieval",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 258.94,
7
+ "map_at_1": 0.67306,
8
+ "map_at_10": 0.80634,
9
+ "map_at_100": 0.81349,
10
+ "map_at_1000": 0.81373,
11
+ "map_at_3": 0.77691,
12
+ "map_at_5": 0.79512,
13
+ "mrr_at_1": 0.7756,
14
+ "mrr_at_10": 0.84177,
15
+ "mrr_at_100": 0.8435,
16
+ "mrr_at_1000": 0.84353,
17
+ "mrr_at_3": 0.83003,
18
+ "mrr_at_5": 0.83799,
19
+ "ndcg_at_1": 0.7758,
20
+ "ndcg_at_10": 0.84782,
21
+ "ndcg_at_100": 0.86443,
22
+ "ndcg_at_1000": 0.86654,
23
+ "ndcg_at_3": 0.8167,
24
+ "ndcg_at_5": 0.83356,
25
+ "precision_at_1": 0.7758,
26
+ "precision_at_10": 0.12875,
27
+ "precision_at_100": 0.01503,
28
+ "precision_at_1000": 0.00156,
29
+ "precision_at_3": 0.3563,
30
+ "precision_at_5": 0.23484,
31
+ "recall_at_1": 0.67306,
32
+ "recall_at_10": 0.9264,
33
+ "recall_at_100": 0.98681,
34
+ "recall_at_1000": 0.9979,
35
+ "recall_at_3": 0.83682,
36
+ "recall_at_5": 0.88424
37
+ }
38
+ }
mteb_results/RedditClustering.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "24640382cdbf8abc73003fb0fa6d111a705499eb",
3
+ "mteb_dataset_name": "RedditClustering",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 237.53,
7
+ "v_measure": 0.5076319866126382,
8
+ "v_measure_std": 0.04676162821389071
9
+ }
10
+ }
mteb_results/RedditClusteringP2P.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "282350215ef01743dc01b456c7f5241fa8937f16",
3
+ "mteb_dataset_name": "RedditClusteringP2P",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 1202.29,
7
+ "v_measure": 0.55024711941649,
8
+ "v_measure_std": 0.12775990781233748
9
+ }
10
+ }
mteb_results/SCIDOCS.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": null,
3
+ "mteb_dataset_name": "SCIDOCS",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "evaluation_time": 82.79,
7
+ "map_at_1": 0.03938,
8
+ "map_at_10": 0.08817,
9
+ "map_at_100": 0.10547,
10
+ "map_at_1000": 0.10852,
11
+ "map_at_3": 0.06352,
12
+ "map_at_5": 0.07453,
13
+ "mrr_at_1": 0.194,
14
+ "mrr_at_10": 0.27371,
15
+ "mrr_at_100": 0.28672,
16
+ "mrr_at_1000": 0.28747,
17
+ "mrr_at_3": 0.24583,
18
+ "mrr_at_5": 0.26143,
19
+ "ndcg_at_1": 0.194,
20
+ "ndcg_at_10": 0.15264,
21
+ "ndcg_at_100": 0.2263,
22
+ "ndcg_at_1000": 0.28559,
23
+ "ndcg_at_3": 0.14425,
24
+ "ndcg_at_5": 0.1252,
25
+ "precision_at_1": 0.194,
26
+ "precision_at_10": 0.0781,
27
+ "precision_at_100": 0.01854,
28
+ "precision_at_1000": 0.00329,
29
+ "precision_at_3": 0.131,
30
+ "precision_at_5": 0.1068,
31
+ "recall_at_1": 0.03938,
32
+ "recall_at_10": 0.15903,
33
+ "recall_at_100": 0.37645,
34
+ "recall_at_1000": 0.6686,
35
+ "recall_at_3": 0.07993,
36
+ "recall_at_5": 0.10885
37
+ }
38
+ }
mteb_results/SICK-R.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "a6ea5a8cab320b040a23452cc28066d9beae2cee",
3
+ "mteb_dataset_name": "SICK-R",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "cos_sim": {
7
+ "pearson": 0.8012689060151424,
8
+ "spearman": 0.7046515535094772
9
+ },
10
+ "euclidean": {
11
+ "pearson": 0.7717160003557223,
12
+ "spearman": 0.704651757047438
13
+ },
14
+ "evaluation_time": 7.91,
15
+ "manhattan": {
16
+ "pearson": 0.7718129609281936,
17
+ "spearman": 0.7046610403752913
18
+ }
19
+ }
20
+ }
mteb_results/STS12.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "a0d554a64d88156834ff5ae9920b964011b16384",
3
+ "mteb_dataset_name": "STS12",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "cos_sim": {
7
+ "pearson": 0.70451157033355,
8
+ "spearman": 0.6399899601697853
9
+ },
10
+ "euclidean": {
11
+ "pearson": 0.6746985359967678,
12
+ "spearman": 0.6400001637764805
13
+ },
14
+ "evaluation_time": 2.34,
15
+ "manhattan": {
16
+ "pearson": 0.6756534741780037,
17
+ "spearman": 0.6406533893575366
18
+ }
19
+ }
20
+ }
mteb_results/STS13.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_revision": "7e90230a92c190f1bf69ae9002b8cea547a64cca",
3
+ "mteb_dataset_name": "STS13",
4
+ "mteb_version": "1.1.0",
5
+ "test": {
6
+ "cos_sim": {
7
+ "pearson": 0.7765086614464292,
8
+ "spearman": 0.7820169706921849
9
+ },
10
+ "euclidean": {
11
+ "pearson": 0.7777758172155284,
12
+ "spearman": 0.7820169706921849
13
+ },
14
+ "evaluation_time": 1.03,
15
+ "manhattan": {
16
+ "pearson": 0.7775077884860052,
17
+ "spearman": 0.7816875216484164
18
+ }
19
+ }
20
+ }