dylanAtHum
commited on
Initial Commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- 1_Pooling/config.json +4 -0
- Data_Records.ipynb +92 -0
- Dataloading.ipynb +675 -0
- README.md +1937 -0
- Replication.txt +40 -0
- Training.py +465 -0
- bert_layers.py +1072 -0
- bert_padding.py +159 -0
- config.json +33 -0
- config_sentence_transformers.json +7 -0
- configuration_bert.py +25 -0
- data_records.json +1 -0
- flash_attn_triton.py +1112 -0
- modules.json +20 -0
- mteb_results/AmazonCounterfactualClassification.json +29 -0
- mteb_results/AmazonPolarityClassification.json +15 -0
- mteb_results/AmazonReviewsClassification.json +25 -0
- mteb_results/ArguAna.json +38 -0
- mteb_results/ArxivClusteringP2P.json +10 -0
- mteb_results/ArxivClusteringS2S.json +10 -0
- mteb_results/AskUbuntuDupQuestions.json +10 -0
- mteb_results/BIOSSES.json +20 -0
- mteb_results/Banking77Classification.json +13 -0
- mteb_results/BiorxivClusteringP2P.json +10 -0
- mteb_results/BiorxivClusteringS2S.json +10 -0
- mteb_results/CQADupstackEnglishRetrieval.json +38 -0
- mteb_results/ClimateFEVER.json +38 -0
- mteb_results/DBPedia.json +38 -0
- mteb_results/EmotionClassification.json +21 -0
- mteb_results/FEVER.json +38 -0
- mteb_results/FiQA2018.json +38 -0
- mteb_results/HotpotQA.json +38 -0
- mteb_results/ImdbClassification.json +15 -0
- mteb_results/MSMARCO.json +38 -0
- mteb_results/MTOPDomainClassification.json +25 -0
- mteb_results/MTOPIntentClassification.json +25 -0
- mteb_results/MassiveIntentClassification.json +25 -0
- mteb_results/MassiveScenarioClassification.json +25 -0
- mteb_results/MedrxivClusteringP2P.json +10 -0
- mteb_results/MedrxivClusteringS2S.json +10 -0
- mteb_results/MindSmallReranking.json +10 -0
- mteb_results/NFCorpus.json +38 -0
- mteb_results/NQ.json +38 -0
- mteb_results/QuoraRetrieval.json +38 -0
- mteb_results/RedditClustering.json +10 -0
- mteb_results/RedditClusteringP2P.json +10 -0
- mteb_results/SCIDOCS.json +38 -0
- mteb_results/SICK-R.json +20 -0
- mteb_results/STS12.json +20 -0
- mteb_results/STS13.json +20 -0
1_Pooling/config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"word_embedding_dimension": 768,
|
3 |
+
"pooling_mode_mean_tokens": true
|
4 |
+
}
|
Data_Records.ipynb
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "e66bbb77-71f5-4d80-b766-f67144ea7a93",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Data Records\n",
|
9 |
+
"\n",
|
10 |
+
"## This notebook generates the data_records.json file where each entry in the resulting dictionary follows the form {filename: num_records} for every dataset we will use during training"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": 39,
|
16 |
+
"id": "74ad6613-44ff-435e-8550-df993e915677",
|
17 |
+
"metadata": {
|
18 |
+
"tags": []
|
19 |
+
},
|
20 |
+
"outputs": [],
|
21 |
+
"source": [
|
22 |
+
"# import relevant libraries\n",
|
23 |
+
"import os\n",
|
24 |
+
"import boto3\n",
|
25 |
+
"import json\n",
|
26 |
+
"from smart_open import open"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": null,
|
32 |
+
"id": "e2d53761-da0e-44f4-8a3e-1285bf810b03",
|
33 |
+
"metadata": {
|
34 |
+
"tags": []
|
35 |
+
},
|
36 |
+
"outputs": [],
|
37 |
+
"source": [
|
38 |
+
"s3 = boto3.resource('s3')\n",
|
39 |
+
"my_bucket = s3.Bucket('lodestone-rnd')\n",
|
40 |
+
"\n",
|
41 |
+
"# collect all filenames from the data/ directory of the lodestone-rnd S3 bucket\n",
|
42 |
+
"files = [\"\"]*((621+12+9+36)+1)\n",
|
43 |
+
"for i, object_summary in enumerate(my_bucket.objects.filter(Prefix=\"data/\")):\n",
|
44 |
+
" files[i] = object_summary.key[5:]\n",
|
45 |
+
"files = files[1:]\n",
|
46 |
+
"files = [file for file in files if file != 'cnn_dailymail_splitted.json.gz']\n",
|
47 |
+
"\n",
|
48 |
+
"s3_client = boto3.client(\"s3\")\n",
|
49 |
+
"\n",
|
50 |
+
"# for each training dataset, store the number of records in a dictionary with the following form {filename: num_records}\n",
|
51 |
+
"data_lengths = {}\n",
|
52 |
+
"for file in files:\n",
|
53 |
+
" source_uri = f's3://lodestone-rnd/data/{file}'\n",
|
54 |
+
" # S2ORC_citations_abstracts.json.gz and amazon-qa.json.gz must be handled differently since each line in their training\n",
|
55 |
+
" # data is split into multiple records due to the fact that each query has multiple positive pair responses\n",
|
56 |
+
" if file in ['S2ORC_citations_abstracts.json.gz','amazon-qa.json.gz']:\n",
|
57 |
+
" length = 0\n",
|
58 |
+
" for json_line in open(source_uri, transport_params={\"client\": s3_client}):\n",
|
59 |
+
" data = json.loads(json_line.strip())\n",
|
60 |
+
" length += len(data['pos'])\n",
|
61 |
+
" else:\n",
|
62 |
+
" length = int(os.popen(f'aws s3 cp {source_uri} - | zcat | wc -l').read().rstrip())\n",
|
63 |
+
" data_lengths[f'{file}'] = length\n",
|
64 |
+
" \n",
|
65 |
+
"# write the resulting dictionary to a .json file for future use during training\n",
|
66 |
+
"with open('data_records.json', 'w') as fileout:\n",
|
67 |
+
" json.dump(data_lengths, fileout)"
|
68 |
+
]
|
69 |
+
}
|
70 |
+
],
|
71 |
+
"metadata": {
|
72 |
+
"kernelspec": {
|
73 |
+
"display_name": "conda_pytorch_p310",
|
74 |
+
"language": "python",
|
75 |
+
"name": "conda_pytorch_p310"
|
76 |
+
},
|
77 |
+
"language_info": {
|
78 |
+
"codemirror_mode": {
|
79 |
+
"name": "ipython",
|
80 |
+
"version": 3
|
81 |
+
},
|
82 |
+
"file_extension": ".py",
|
83 |
+
"mimetype": "text/x-python",
|
84 |
+
"name": "python",
|
85 |
+
"nbconvert_exporter": "python",
|
86 |
+
"pygments_lexer": "ipython3",
|
87 |
+
"version": "3.10.10"
|
88 |
+
}
|
89 |
+
},
|
90 |
+
"nbformat": 4,
|
91 |
+
"nbformat_minor": 5
|
92 |
+
}
|
Dataloading.ipynb
ADDED
@@ -0,0 +1,675 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "3f9ce240-fd1a-4550-83c7-8cf9658b1d3a",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Dataloading (1B+ Training Pairs)\n",
|
9 |
+
"\n",
|
10 |
+
"## This notebook collects and uploads all 50 relevant sentence embedding datasets to S3 as .json.gz files where each line contains one training record"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": 62,
|
16 |
+
"id": "d7af8d0e-b3ed-4007-abdd-5952d775e119",
|
17 |
+
"metadata": {
|
18 |
+
"tags": []
|
19 |
+
},
|
20 |
+
"outputs": [],
|
21 |
+
"source": [
|
22 |
+
"import os\n",
|
23 |
+
"import pandas as pd\n",
|
24 |
+
"\n",
|
25 |
+
"os.chdir('/home/ec2-user/SageMaker')"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "code",
|
30 |
+
"execution_count": 63,
|
31 |
+
"id": "686699fb-d2c9-4653-afca-580aef343451",
|
32 |
+
"metadata": {
|
33 |
+
"collapsed": true,
|
34 |
+
"jupyter": {
|
35 |
+
"outputs_hidden": true
|
36 |
+
},
|
37 |
+
"tags": []
|
38 |
+
},
|
39 |
+
"outputs": [
|
40 |
+
{
|
41 |
+
"name": "stdout",
|
42 |
+
"output_type": "stream",
|
43 |
+
"text": [
|
44 |
+
"Loaded plugins: dkms-build-requires, extras_suggestions, langpacks, priorities,\n",
|
45 |
+
" : update-motd, versionlock\n",
|
46 |
+
"Cleaning repos: amzn2-core amzn2extra-docker amzn2extra-epel\n",
|
47 |
+
" : amzn2extra-kernel-5.10 amzn2extra-python3.8 centos-extras\n",
|
48 |
+
" : copr:copr.fedorainfracloud.org:vbatts:shadow-utils-newxidmap\n",
|
49 |
+
" : docker-ce-stable libnvidia-container neuron\n",
|
50 |
+
"21 metadata files removed\n",
|
51 |
+
"15 sqlite files removed\n",
|
52 |
+
"0 metadata files removed\n",
|
53 |
+
"Loaded plugins: dkms-build-requires, extras_suggestions, langpacks, priorities,\n",
|
54 |
+
" : update-motd, versionlock\n"
|
55 |
+
]
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"name": "stderr",
|
59 |
+
"output_type": "stream",
|
60 |
+
"text": [
|
61 |
+
"https://download.docker.com/linux/centos/2/x86_64/stable/repodata/repomd.xml: [Errno 14] HTTPS Error 404 - Not Found\n",
|
62 |
+
"Trying other mirror.\n"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"name": "stdout",
|
67 |
+
"output_type": "stream",
|
68 |
+
"text": [
|
69 |
+
"62 packages excluded due to repository priority protections\n",
|
70 |
+
"Resolving Dependencies\n",
|
71 |
+
"--> Running transaction check\n",
|
72 |
+
"---> Package epel-release.noarch 0:7-11 will be installed\n",
|
73 |
+
"--> Finished Dependency Resolution\n",
|
74 |
+
"\n",
|
75 |
+
"Dependencies Resolved\n",
|
76 |
+
"\n",
|
77 |
+
"================================================================================\n",
|
78 |
+
" Package Arch Version Repository Size\n",
|
79 |
+
"================================================================================\n",
|
80 |
+
"Installing:\n",
|
81 |
+
" epel-release noarch 7-11 amzn2extra-epel 15 k\n",
|
82 |
+
"\n",
|
83 |
+
"Transaction Summary\n",
|
84 |
+
"================================================================================\n",
|
85 |
+
"Install 1 Package\n",
|
86 |
+
"\n",
|
87 |
+
"Total download size: 15 k\n",
|
88 |
+
"Installed size: 24 k\n",
|
89 |
+
"Downloading packages:\n",
|
90 |
+
"Running transaction check\n",
|
91 |
+
"Running transaction test\n",
|
92 |
+
"Transaction test succeeded\n",
|
93 |
+
"Running transaction\n"
|
94 |
+
]
|
95 |
+
},
|
96 |
+
{
|
97 |
+
"name": "stderr",
|
98 |
+
"output_type": "stream",
|
99 |
+
"text": [
|
100 |
+
"Warning: RPMDB altered outside of yum.\n"
|
101 |
+
]
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"name": "stdout",
|
105 |
+
"output_type": "stream",
|
106 |
+
"text": [
|
107 |
+
" Installing : epel-release-7-11.noarch 1/1 \n",
|
108 |
+
" Verifying : epel-release-7-11.noarch 1/1 \n",
|
109 |
+
"\n",
|
110 |
+
"Installed:\n",
|
111 |
+
" epel-release.noarch 0:7-11 \n",
|
112 |
+
"\n",
|
113 |
+
"Complete!\n",
|
114 |
+
"Installing epel-release\n",
|
115 |
+
" 0 ansible2 available \\\n",
|
116 |
+
" [ =2.4.2 =2.4.6 =2.8 =stable ]\n",
|
117 |
+
" 2 httpd_modules available [ =1.0 =stable ]\n",
|
118 |
+
" 3 memcached1.5 available \\\n",
|
119 |
+
" [ =1.5.1 =1.5.16 =1.5.17 ]\n",
|
120 |
+
" 6 postgresql10 available [ =10 =stable ]\n",
|
121 |
+
" 9 R3.4 available [ =3.4.3 =stable ]\n",
|
122 |
+
" 10 rust1 available \\\n",
|
123 |
+
" [ =1.22.1 =1.26.0 =1.26.1 =1.27.2 =1.31.0 =1.38.0\n",
|
124 |
+
" =stable ]\n",
|
125 |
+
" 18 libreoffice available \\\n",
|
126 |
+
" [ =5.0.6.2_15 =5.3.6.1 =stable ]\n",
|
127 |
+
" 19 gimp available [ =2.8.22 ]\n",
|
128 |
+
" 20 docker=latest enabled \\\n",
|
129 |
+
" [ =17.12.1 =18.03.1 =18.06.1 =18.09.9 =stable ]\n",
|
130 |
+
" 21 mate-desktop1.x available \\\n",
|
131 |
+
" [ =1.19.0 =1.20.0 =stable ]\n",
|
132 |
+
" 22 GraphicsMagick1.3 available \\\n",
|
133 |
+
" [ =1.3.29 =1.3.32 =1.3.34 =stable ]\n",
|
134 |
+
" 23 tomcat8.5 available \\\n",
|
135 |
+
" [ =8.5.31 =8.5.32 =8.5.38 =8.5.40 =8.5.42 =8.5.50\n",
|
136 |
+
" =stable ]\n",
|
137 |
+
" 24 epel=latest enabled [ =7.11 =stable ]\n",
|
138 |
+
" 25 testing available [ =1.0 =stable ]\n",
|
139 |
+
" 26 ecs available [ =stable ]\n",
|
140 |
+
" 27 corretto8 available \\\n",
|
141 |
+
" [ =1.8.0_192 =1.8.0_202 =1.8.0_212 =1.8.0_222 =1.8.0_232\n",
|
142 |
+
" =1.8.0_242 =stable ]\n",
|
143 |
+
" 29 golang1.11 available \\\n",
|
144 |
+
" [ =1.11.3 =1.11.11 =1.11.13 =stable ]\n",
|
145 |
+
" 30 squid4 available [ =4 =stable ]\n",
|
146 |
+
" 32 lustre2.10 available \\\n",
|
147 |
+
" [ =2.10.5 =2.10.8 =stable ]\n",
|
148 |
+
" 33 java-openjdk11 available [ =11 =stable ]\n",
|
149 |
+
" 34 lynis available [ =stable ]\n",
|
150 |
+
" 36 BCC available [ =0.x =stable ]\n",
|
151 |
+
" 37 mono available [ =5.x =stable ]\n",
|
152 |
+
" 38 nginx1 available [ =stable ]\n",
|
153 |
+
" 40 mock available [ =stable ]\n",
|
154 |
+
" 41 postgresql11 available [ =11 =stable ]\n",
|
155 |
+
" 43 livepatch available [ =stable ]\n",
|
156 |
+
" 44 python3.8=latest enabled [ =stable ]\n",
|
157 |
+
" 45 haproxy2 available [ =stable ]\n",
|
158 |
+
" 46 collectd available [ =stable ]\n",
|
159 |
+
" 47 aws-nitro-enclaves-cli available [ =stable ]\n",
|
160 |
+
" 48 R4 available [ =stable ]\n",
|
161 |
+
" _ kernel-5.4 available [ =stable ]\n",
|
162 |
+
" 50 selinux-ng available [ =stable ]\n",
|
163 |
+
" 51 php8.0 available [ =stable ]\n",
|
164 |
+
" 52 tomcat9 available [ =stable ]\n",
|
165 |
+
" 53 unbound1.13 available [ =stable ]\n",
|
166 |
+
" 54 mariadb10.5 available [ =stable ]\n",
|
167 |
+
" 55 kernel-5.10=latest enabled [ =stable ]\n",
|
168 |
+
" 56 redis6 available [ =stable ]\n",
|
169 |
+
" 57 ruby3.0 available [ =stable ]\n",
|
170 |
+
" 58 postgresql12 available [ =stable ]\n",
|
171 |
+
" 59 postgresql13 available [ =stable ]\n",
|
172 |
+
" 60 mock2 available [ =stable ]\n",
|
173 |
+
" 61 dnsmasq2.85 available [ =stable ]\n",
|
174 |
+
" 62 kernel-5.15 available [ =stable ]\n",
|
175 |
+
" 63 postgresql14 available [ =stable ]\n",
|
176 |
+
" 64 firefox available [ =stable ]\n",
|
177 |
+
" 65 lustre available [ =stable ]\n",
|
178 |
+
" 66 php8.1 available [ =stable ]\n",
|
179 |
+
" 67 awscli1 available [ =stable ]\n",
|
180 |
+
" 68 php8.2 available [ =stable ]\n",
|
181 |
+
" 69 dnsmasq available [ =stable ]\n",
|
182 |
+
" 70 unbound1.17 available [ =stable ]\n",
|
183 |
+
" 71 golang1.19 available [ =stable ]\n",
|
184 |
+
" 72 collectd-python3 available [ =stable ]\n",
|
185 |
+
"Loaded plugins: dkms-build-requires, extras_suggestions, langpacks, priorities,\n",
|
186 |
+
" : update-motd, versionlock\n",
|
187 |
+
"================================== repo: epel ==================================\n",
|
188 |
+
"[epel]\n",
|
189 |
+
"async = True\n",
|
190 |
+
"bandwidth = 0\n",
|
191 |
+
"base_persistdir = /var/lib/yum/repos/x86_64/2\n",
|
192 |
+
"baseurl = \n",
|
193 |
+
"cache = 0\n",
|
194 |
+
"cachedir = /var/cache/yum/x86_64/2/epel\n",
|
195 |
+
"check_config_file_age = True\n",
|
196 |
+
"compare_providers_priority = 80\n",
|
197 |
+
"cost = 1000\n",
|
198 |
+
"deltarpm_metadata_percentage = 100\n",
|
199 |
+
"deltarpm_percentage = \n",
|
200 |
+
"enabled = True\n",
|
201 |
+
"enablegroups = True\n",
|
202 |
+
"exclude = \n",
|
203 |
+
"failovermethod = priority\n",
|
204 |
+
"ftp_disable_epsv = False\n",
|
205 |
+
"gpgcadir = /var/lib/yum/repos/x86_64/2/epel/gpgcadir\n",
|
206 |
+
"gpgcakey = \n",
|
207 |
+
"gpgcheck = True\n",
|
208 |
+
"gpgdir = /var/lib/yum/repos/x86_64/2/epel/gpgdir\n",
|
209 |
+
"gpgkey = file:///etc/pki/rpm-gpg/RPM-GPG-KEY-EPEL-7\n",
|
210 |
+
"hdrdir = /var/cache/yum/x86_64/2/epel/headers\n",
|
211 |
+
"http_caching = all\n",
|
212 |
+
"includepkgs = \n",
|
213 |
+
"ip_resolve = \n",
|
214 |
+
"keepalive = True\n",
|
215 |
+
"keepcache = False\n",
|
216 |
+
"mddownloadpolicy = sqlite\n",
|
217 |
+
"mdpolicy = group:small\n",
|
218 |
+
"mediaid = \n",
|
219 |
+
"metadata_expire = 21600\n",
|
220 |
+
"metadata_expire_filter = read-only:present\n",
|
221 |
+
"metalink = https://mirrors.fedoraproject.org/metalink?repo=epel-7&arch=x86_64\n",
|
222 |
+
"minrate = 0\n",
|
223 |
+
"mirrorlist = \n",
|
224 |
+
"mirrorlist_expire = 86400\n",
|
225 |
+
"name = Extra Packages for Enterprise Linux 7 - x86_64\n",
|
226 |
+
"old_base_cache_dir = \n",
|
227 |
+
"password = \n",
|
228 |
+
"persistdir = /var/lib/yum/repos/x86_64/2/epel\n",
|
229 |
+
"pkgdir = /var/cache/yum/x86_64/2/epel/packages\n",
|
230 |
+
"priority = 99\n",
|
231 |
+
"proxy = False\n",
|
232 |
+
"proxy_dict = \n",
|
233 |
+
"proxy_password = \n",
|
234 |
+
"proxy_username = \n",
|
235 |
+
"repo_gpgcheck = False\n",
|
236 |
+
"report_instanceid = False\n",
|
237 |
+
"retries = 7\n",
|
238 |
+
"skip_if_unavailable = False\n",
|
239 |
+
"ssl_check_cert_permissions = True\n",
|
240 |
+
"sslcacert = \n",
|
241 |
+
"sslclientcert = \n",
|
242 |
+
"sslclientkey = \n",
|
243 |
+
"sslverify = True\n",
|
244 |
+
"throttle = 0\n",
|
245 |
+
"timeout = 5.0\n",
|
246 |
+
"ui_id = epel/x86_64\n",
|
247 |
+
"ui_repoid_vars = releasever,\n",
|
248 |
+
" basearch\n",
|
249 |
+
"username = \n",
|
250 |
+
"\n",
|
251 |
+
"Loaded plugins: dkms-build-requires, extras_suggestions, langpacks, priorities,\n",
|
252 |
+
" : update-motd, versionlock\n"
|
253 |
+
]
|
254 |
+
},
|
255 |
+
{
|
256 |
+
"name": "stderr",
|
257 |
+
"output_type": "stream",
|
258 |
+
"text": [
|
259 |
+
"https://download.docker.com/linux/centos/2/x86_64/stable/repodata/repomd.xml: [Errno 14] HTTPS Error 404 - Not Found\n",
|
260 |
+
"Trying other mirror.\n",
|
261 |
+
"http://mirror.es.its.nyu.edu/epel/7/x86_64/repodata/repomd.xml: [Errno 12] Timeout on http://mirror.es.its.nyu.edu/epel/7/x86_64/repodata/repomd.xml: (28, 'Failed to connect to mirror.es.its.nyu.edu port 80 after 5001 ms: Timeout was reached')\n",
|
262 |
+
"Trying other mirror.\n"
|
263 |
+
]
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"name": "stdout",
|
267 |
+
"output_type": "stream",
|
268 |
+
"text": [
|
269 |
+
"286 packages excluded due to repository priority protections\n",
|
270 |
+
"Resolving Dependencies\n",
|
271 |
+
"--> Running transaction check\n",
|
272 |
+
"---> Package git-lfs.x86_64 0:2.10.0-2.el7 will be installed\n",
|
273 |
+
"--> Finished Dependency Resolution\n",
|
274 |
+
"\n",
|
275 |
+
"Dependencies Resolved\n",
|
276 |
+
"\n",
|
277 |
+
"================================================================================\n",
|
278 |
+
" Package Arch Version Repository Size\n",
|
279 |
+
"================================================================================\n",
|
280 |
+
"Installing:\n",
|
281 |
+
" git-lfs x86_64 2.10.0-2.el7 epel 3.7 M\n",
|
282 |
+
"\n",
|
283 |
+
"Transaction Summary\n",
|
284 |
+
"================================================================================\n",
|
285 |
+
"Install 1 Package\n",
|
286 |
+
"\n",
|
287 |
+
"Total download size: 3.7 M\n",
|
288 |
+
"Installed size: 13 M\n",
|
289 |
+
"Downloading packages:\n"
|
290 |
+
]
|
291 |
+
},
|
292 |
+
{
|
293 |
+
"name": "stderr",
|
294 |
+
"output_type": "stream",
|
295 |
+
"text": [
|
296 |
+
"warning: /var/cache/yum/x86_64/2/epel/packages/git-lfs-2.10.0-2.el7.x86_64.rpm: Header V4 RSA/SHA256 Signature, key ID 352c64e5: NOKEY\n",
|
297 |
+
"Importing GPG key 0x352C64E5:\n",
|
298 |
+
" Userid : \"Fedora EPEL (7) <epel@fedoraproject.org>\"\n",
|
299 |
+
" Fingerprint: 91e9 7d7c 4a5e 96f1 7f3e 888f 6a2f aea2 352c 64e5\n",
|
300 |
+
" Package : epel-release-7-11.noarch (@amzn2extra-epel)\n",
|
301 |
+
" From : /etc/pki/rpm-gpg/RPM-GPG-KEY-EPEL-7\n"
|
302 |
+
]
|
303 |
+
},
|
304 |
+
{
|
305 |
+
"name": "stdout",
|
306 |
+
"output_type": "stream",
|
307 |
+
"text": [
|
308 |
+
"Public key for git-lfs-2.10.0-2.el7.x86_64.rpm is not installed\n",
|
309 |
+
"Retrieving key from file:///etc/pki/rpm-gpg/RPM-GPG-KEY-EPEL-7\n",
|
310 |
+
"Running transaction check\n",
|
311 |
+
"Running transaction test\n",
|
312 |
+
"Transaction test succeeded\n",
|
313 |
+
"Running transaction\n",
|
314 |
+
" Installing : git-lfs-2.10.0-2.el7.x86_64 1/1 \n",
|
315 |
+
" Verifying : git-lfs-2.10.0-2.el7.x86_64 1/1 \n",
|
316 |
+
"\n",
|
317 |
+
"Installed:\n",
|
318 |
+
" git-lfs.x86_64 0:2.10.0-2.el7 \n",
|
319 |
+
"\n",
|
320 |
+
"Complete!\n",
|
321 |
+
"Git LFS initialized.\n"
|
322 |
+
]
|
323 |
+
},
|
324 |
+
{
|
325 |
+
"data": {
|
326 |
+
"text/plain": [
|
327 |
+
"0"
|
328 |
+
]
|
329 |
+
},
|
330 |
+
"execution_count": 63,
|
331 |
+
"metadata": {},
|
332 |
+
"output_type": "execute_result"
|
333 |
+
}
|
334 |
+
],
|
335 |
+
"source": [
|
336 |
+
"# install git-lfs\n",
|
337 |
+
"os.system('sudo amazon-linux-extras install epel -y')\n",
|
338 |
+
"os.system('sudo yum-config-manager --enable epel')\n",
|
339 |
+
"os.system('sudo yum install git-lfs -y')\n",
|
340 |
+
"os.system('git lfs install')"
|
341 |
+
]
|
342 |
+
},
|
343 |
+
{
|
344 |
+
"cell_type": "code",
|
345 |
+
"execution_count": 3,
|
346 |
+
"id": "6595a6c5-9fae-4bd4-b096-53475ec98294",
|
347 |
+
"metadata": {
|
348 |
+
"tags": []
|
349 |
+
},
|
350 |
+
"outputs": [
|
351 |
+
{
|
352 |
+
"name": "stderr",
|
353 |
+
"output_type": "stream",
|
354 |
+
"text": [
|
355 |
+
"Cloning into 'stackexchange_title_body_jsonl'...\n",
|
356 |
+
"Cloning into 'stackexchange_titlebody_best_voted_answer_jsonl'...\n",
|
357 |
+
"Cloning into 'stackexchange_title_best_voted_answer_jsonl'...\n",
|
358 |
+
"Cloning into 'stackexchange_titlebody_best_and_down_voted_answer_jsonl'...\n",
|
359 |
+
"Cloning into 'reddit-title-body'...\n",
|
360 |
+
"Cloning into '1B_sentence_embeddings'...\n"
|
361 |
+
]
|
362 |
+
},
|
363 |
+
{
|
364 |
+
"data": {
|
365 |
+
"text/plain": [
|
366 |
+
"0"
|
367 |
+
]
|
368 |
+
},
|
369 |
+
"execution_count": 3,
|
370 |
+
"metadata": {},
|
371 |
+
"output_type": "execute_result"
|
372 |
+
}
|
373 |
+
],
|
374 |
+
"source": [
|
375 |
+
"# clone relevant datasets' github repositories\n",
|
376 |
+
"stacks = ['stackexchange_title_body_jsonl', #25.3M \n",
|
377 |
+
" 'stackexchange_titlebody_best_voted_answer_jsonl', #4.75M\n",
|
378 |
+
" 'stackexchange_title_best_voted_answer_jsonl', #4.75M\n",
|
379 |
+
" 'stackexchange_titlebody_best_and_down_voted_answer_jsonl'] #210K\n",
|
380 |
+
"\n",
|
381 |
+
"os.environ['GIT_LFS_SKIP_SMUDGE'] = \"1\"\n",
|
382 |
+
"\n",
|
383 |
+
"# clone stackexchange repos\n",
|
384 |
+
"for stack in stacks:\n",
|
385 |
+
" os.system(f'git clone https://huggingface.co/datasets/flax-sentence-embeddings/{stack}')\n",
|
386 |
+
"# clone reddit repo\n",
|
387 |
+
"os.system('git clone https://huggingface.co/datasets/sentence-transformers/reddit-title-body')\n",
|
388 |
+
"# clone 1B+ sentence embeddings repo (this one is just for reference)\n",
|
389 |
+
"os.system('git clone https://github.com/AntoineSimoulin/1B_sentence_embeddings')"
|
390 |
+
]
|
391 |
+
},
|
392 |
+
{
|
393 |
+
"cell_type": "code",
|
394 |
+
"execution_count": null,
|
395 |
+
"id": "56438029-0f26-491c-b405-f92fb9caeff7",
|
396 |
+
"metadata": {
|
397 |
+
"collapsed": true,
|
398 |
+
"jupyter": {
|
399 |
+
"outputs_hidden": true
|
400 |
+
},
|
401 |
+
"tags": []
|
402 |
+
},
|
403 |
+
"outputs": [
|
404 |
+
{
|
405 |
+
"name": "stdout",
|
406 |
+
"output_type": "stream",
|
407 |
+
"text": [
|
408 |
+
"Downloading 4 StackExchange GitHub datasets into s3://lodestone-rnd/data/\n",
|
409 |
+
"upload: ./networkengineering.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/networkengineering.stackexchange.com.json.gz\n",
|
410 |
+
"upload: ./emacs.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/emacs.stackexchange.com.json.gz\n",
|
411 |
+
"upload: ./christianity.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/christianity.stackexchange.com.json.gz\n",
|
412 |
+
"upload: ./bitcoin.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/bitcoin.stackexchange.com.json.gz\n",
|
413 |
+
"upload: ./academia.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/academia.stackexchange.com.json.gz\n",
|
414 |
+
"upload: ./music.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/music.stackexchange.com.json.gz\n",
|
415 |
+
"upload: ./biology.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/biology.stackexchange.com.json.gz\n",
|
416 |
+
"upload: ./history.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/history.stackexchange.com.json.gz\n",
|
417 |
+
"upload: ./skeptics.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/skeptics.stackexchange.com.json.gz\n",
|
418 |
+
"upload: ./anime.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/anime.stackexchange.com.json.gz\n",
|
419 |
+
"upload: ./quant.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/quant.stackexchange.com.json.gz\n",
|
420 |
+
"upload: ./boardgames.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/boardgames.stackexchange.com.json.gz\n",
|
421 |
+
"upload: ./judaism.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/judaism.stackexchange.com.json.gz\n",
|
422 |
+
"upload: ./travel.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/travel.stackexchange.com.json.gz\n",
|
423 |
+
"upload: ./gaming.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/gaming.stackexchange.com.json.gz\n",
|
424 |
+
"upload: ./webapps.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/webapps.stackexchange.com.json.gz\n",
|
425 |
+
"upload: ./stats.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/stats.stackexchange.com.json.gz\n",
|
426 |
+
"upload: ./law.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/law.stackexchange.com.json.gz\n",
|
427 |
+
"upload: ./scifi.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/scifi.stackexchange.com.json.gz\n",
|
428 |
+
"upload: ./bicycles.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/bicycles.stackexchange.com.json.gz\n",
|
429 |
+
"upload: ./datascience.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/datascience.stackexchange.com.json.gz\n",
|
430 |
+
"upload: ./softwareengineering.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/softwareengineering.stackexchange.com.json.gz\n",
|
431 |
+
"upload: ./islam.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/islam.stackexchange.com.json.gz\n",
|
432 |
+
"upload: ./craftcms.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/craftcms.stackexchange.com.json.gz\n",
|
433 |
+
"upload: ./diy.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/diy.stackexchange.com.json.gz\n",
|
434 |
+
"upload: ./arduino.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/arduino.stackexchange.com.json.gz\n",
|
435 |
+
"upload: ./raspberrypi.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/raspberrypi.stackexchange.com.json.gz\n",
|
436 |
+
"upload: ./wordpress.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/wordpress.stackexchange.com.json.gz\n",
|
437 |
+
"upload: ./dba.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/dba.stackexchange.com.json.gz\n",
|
438 |
+
"upload: ./apple.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/apple.stackexchange.com.json.gz\n",
|
439 |
+
"upload: ./hinduism.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/hinduism.stackexchange.com.json.gz\n",
|
440 |
+
"upload: ./mechanics.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/mechanics.stackexchange.com.json.gz\n",
|
441 |
+
"upload: ./gamedev.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/gamedev.stackexchange.com.json.gz\n",
|
442 |
+
"upload: ./writers.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/writers.stackexchange.com.json.gz\n",
|
443 |
+
"upload: ./mathematica.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/mathematica.stackexchange.com.json.gz\n",
|
444 |
+
"upload: ./unix.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/unix.stackexchange.com.json.gz\n",
|
445 |
+
"upload: ./magento.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/magento.stackexchange.com.json.gz\n",
|
446 |
+
"upload: ./ethereum.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/ethereum.stackexchange.com.json.gz\n",
|
447 |
+
"upload: ./electronics.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/electronics.stackexchange.com.json.gz\n",
|
448 |
+
"upload: ./cs.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/cs.stackexchange.com.json.gz\n",
|
449 |
+
"upload: ./blender.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/blender.stackexchange.com.json.gz\n",
|
450 |
+
"upload: ./drupal.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/drupal.stackexchange.com.json.gz\n",
|
451 |
+
"upload: ./small_stackexchanges.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/small_stackexchanges.json.gz\n",
|
452 |
+
"upload: ./photo.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/photo.stackexchange.com.json.gz\n",
|
453 |
+
"upload: ./engineering.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/engineering.stackexchange.com.json.gz\n",
|
454 |
+
"upload: ./ux.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/ux.stackexchange.com.json.gz\n",
|
455 |
+
"upload: ./german.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/german.stackexchange.com.json.gz\n",
|
456 |
+
"upload: ./japanese.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/japanese.stackexchange.com.json.gz\n",
|
457 |
+
"upload: ./civicrm.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/civicrm.stackexchange.com.json.gz\n",
|
458 |
+
"upload: ./sharepoint.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/sharepoint.stackexchange.com.json.gz\n",
|
459 |
+
"upload: ./mathoverflow.net.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/mathoverflow.net.json.gz\n",
|
460 |
+
"upload: ./meta.stackoverflow.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/meta.stackoverflow.com.json.gz\n",
|
461 |
+
"upload: ./rpg.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/rpg.stackexchange.com.json.gz\n",
|
462 |
+
"upload: ./crypto.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/crypto.stackexchange.com.json.gz\n",
|
463 |
+
"upload: ./vi.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/vi.stackexchange.com.json.gz\n",
|
464 |
+
"upload: ./graphicdesign.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/graphicdesign.stackexchange.com.json.gz\n",
|
465 |
+
"upload: ./cooking.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/cooking.stackexchange.com.json.gz\n",
|
466 |
+
"upload: ./math.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/math.stackexchange.com.json.gz\n",
|
467 |
+
"upload: ./expressionengine.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/expressionengine.stackexchange.com.json.gz\n",
|
468 |
+
"upload: ./movies.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/movies.stackexchange.com.json.gz\n",
|
469 |
+
"upload: ./salesforce.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/salesforce.stackexchange.com.json.gz\n",
|
470 |
+
"upload: ./physics.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/physics.stackexchange.com.json.gz\n",
|
471 |
+
"upload: ./aviation.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/aviation.stackexchange.com.json.gz\n",
|
472 |
+
"upload: ./gardening.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/gardening.stackexchange.com.json.gz\n",
|
473 |
+
"upload: ./english.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/english.stackexchange.com.json.gz\n",
|
474 |
+
"upload: ./askubuntu.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/askubuntu.com.json.gz\n",
|
475 |
+
"upload: ./french.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/french.stackexchange.com.json.gz\n",
|
476 |
+
"upload: ./codereview.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/codereview.stackexchange.com.json.gz\n",
|
477 |
+
"upload: ./softwarerecs.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/softwarerecs.stackexchange.com.json.gz\n",
|
478 |
+
"upload: ./rus.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/rus.stackexchange.com.json.gz\n",
|
479 |
+
"upload: ./money.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/money.stackexchange.com.json.gz\n",
|
480 |
+
"upload: ./philosophy.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/philosophy.stackexchange.com.json.gz\n",
|
481 |
+
"upload: ./chemistry.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/chemistry.stackexchange.com.json.gz\n",
|
482 |
+
"upload: ./meta.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/meta.stackexchange.com.json.gz\n",
|
483 |
+
"upload: ./cstheory.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/cstheory.stackexchange.com.json.gz\n",
|
484 |
+
"upload: ./space.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/space.stackexchange.com.json.gz\n",
|
485 |
+
"upload: ./politics.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/politics.stackexchange.com.json.gz\n",
|
486 |
+
"upload: ./ell.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/ell.stackexchange.com.json.gz\n",
|
487 |
+
"upload: ./puzzling.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/puzzling.stackexchange.com.json.gz\n",
|
488 |
+
"upload: ./astronomy.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/astronomy.stackexchange.com.json.gz\n",
|
489 |
+
"upload: ./worldbuilding.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/worldbuilding.stackexchange.com.json.gz\n",
|
490 |
+
"upload: ./economics.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/economics.stackexchange.com.json.gz\n",
|
491 |
+
"upload: ./workplace.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/workplace.stackexchange.com.json.gz\n",
|
492 |
+
"upload: ./tex.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/tex.stackexchange.com.json.gz\n",
|
493 |
+
"upload: ./android.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/android.stackexchange.com.json.gz\n",
|
494 |
+
"upload: ./gis.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/gis.stackexchange.com.json.gz\n",
|
495 |
+
"upload: ./dsp.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/dsp.stackexchange.com.json.gz\n",
|
496 |
+
"upload: ./superuser.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_body_jsonl/superuser.com.json.gz\n",
|
497 |
+
"upload: ./english.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/english.stackexchange.com.json.gz\n",
|
498 |
+
"upload: ./meta.serverfault.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/meta.serverfault.com.json.gz\n",
|
499 |
+
"upload: ./scicomp.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/scicomp.stackexchange.com.json.gz\n",
|
500 |
+
"upload: ./askubuntu.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/askubuntu.com.json.gz\n",
|
501 |
+
"upload: ./french.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/french.stackexchange.com.json.gz\n",
|
502 |
+
"upload: ./coffee.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/coffee.stackexchange.com.json.gz\n",
|
503 |
+
"upload: ./codereview.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/codereview.stackexchange.com.json.gz\n",
|
504 |
+
"upload: ./sound.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/sound.stackexchange.com.json.gz\n",
|
505 |
+
"upload: ./opensource.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/opensource.stackexchange.com.json.gz\n",
|
506 |
+
"upload: ./woodworking.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/woodworking.stackexchange.com.json.gz\n",
|
507 |
+
"upload: ./outdoors.stackexchange.com.jsonl.gz to s3://lodestone-rnd/data/stackexchange_title_best_voted_answer_jsonl/outdoors.stackexchange.com.json.gz\n",
|
508 |
+
"upload: ./reddit_title_text_2018.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2018.json.gz\n",
|
509 |
+
"upload: ./reddit_title_text_2011.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2011.json.gz\n",
|
510 |
+
"upload: ./reddit_title_text_2020.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2020.json.gz\n",
|
511 |
+
"upload: ./reddit_title_text_2012.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2012.json.gz\n",
|
512 |
+
"upload: ./reddit_title_text_2021.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2021.json.gz\n",
|
513 |
+
"upload: ./reddit_title_text_2019.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2019.json.gz\n",
|
514 |
+
"upload: ./reddit_title_text_2010.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2010.json.gz\n",
|
515 |
+
"upload: ./reddit_title_text_2014.jsonl.gz to s3://lodestone-rnd/data/reddit-title-body/reddit_title_text_2014.json.gz\n",
|
516 |
+
"\u001b[32mDone\u001b[0m\n",
|
517 |
+
"Total files uploaded: 12\n",
|
518 |
+
"\n",
|
519 |
+
"\n",
|
520 |
+
"Downloading 9 HuggingFace datasets into s3://lodestone-rnd/data/\n",
|
521 |
+
"Downloading dataset S2ORC_citations_abstracts (39,567,485 pairs) ... "
|
522 |
+
]
|
523 |
+
}
|
524 |
+
],
|
525 |
+
"source": [
|
526 |
+
"# DOWNLOAD GITHUB DATASETS (STACKEXCHANGE (https://huggingface.co/flax-sentence-embeddings) & REDDIT (https://huggingface.co/datasets/sentence-transformers/reddit-title-body))\n",
|
527 |
+
"\n",
|
528 |
+
"# these are the files marked as unsafe by HuggingFace when viewing the each of the datasets' pages\n",
|
529 |
+
"unsafe = [[\"serverfault.com.jsonl.gz\", \"security.stackexchange.com.jsonl.gz\"],\n",
|
530 |
+
" [\"monero.stackexchange.com.jsonl.gz\", \"serverfault.com.jsonl.gz\", \"security.stackexchange.com.jsonl.gz\"],\n",
|
531 |
+
" [\"elementaryos.stackexchange.com.jsonl.gz\", \"monero.stackexchange.com.jsonl.gz\", \"security.stackexchange.com.jsonl.gz\"],\n",
|
532 |
+
" [\"\"]]\n",
|
533 |
+
"\n",
|
534 |
+
"file_counts = []\n",
|
535 |
+
"print('Downloading {:,} StackExchange GitHub datasets into s3://lodestone-rnd/data/'.format(len(stacks)))\n",
|
536 |
+
"for i, stack in enumerate(stacks):\n",
|
537 |
+
" # get the names of all the files in the repository that are not unsafe\n",
|
538 |
+
" files = [file for file in os.listdir(f'/home/ec2-user/SageMaker/{stack}') if file.endswith(\".jsonl.gz\")==True if file not in unsafe[i]]\n",
|
539 |
+
" file_counts.append(len(files))\n",
|
540 |
+
" os.chdir(f'/home/ec2-user/SageMaker/{stack}')\n",
|
541 |
+
" print('Downloading dataset {} ({} files) ... '.format(stack, len(files)), end='', flush=True)\n",
|
542 |
+
" # sequentially pull each dataset from git lfs, stream it to S3, and then delete the local copy to free up disk memory\n",
|
543 |
+
" for file_name in files:\n",
|
544 |
+
" os.system(f'git lfs pull --include={file_name}')\n",
|
545 |
+
" os.system(f'aws s3 cp {file_name} s3://lodestone-rnd/data/{stack}/{file_name[:-9] + \".json.gz\"}')\n",
|
546 |
+
" os.remove(file_name)\n",
|
547 |
+
" os.system('rm -r .git/lfs/objects/*')\n",
|
548 |
+
" if len(os.listdir('.git/objects/pack')) == 4:\n",
|
549 |
+
" os.system('ls -t .git/objects/pack/* | head -2 | xargs rm --')\n",
|
550 |
+
" print('\\033[32m' + 'Done' + '\\033[0m')\n",
|
551 |
+
"print(f'Total files uploaded: {sum(file_counts)}')\n",
|
552 |
+
"\n",
|
553 |
+
"print(\"\\n\")\n",
|
554 |
+
"\n",
|
555 |
+
"print('Downloading {:,} Reddit GitHub dataset into s3://lodestone-rnd/data/'.format(1))\n",
|
556 |
+
"# get the names of all the files in the repository\n",
|
557 |
+
"files = [file for file in os.listdir(f'/home/ec2-user/SageMaker/reddit-title-body') if file.endswith(\".jsonl.gz\")==True]\n",
|
558 |
+
"os.chdir(f'/home/ec2-user/SageMaker/reddit-title-body')\n",
|
559 |
+
"print('Downloading dataset {} ({} files) ... '.format(\"reddit-title-body\", len(files)), end='', flush=True)\n",
|
560 |
+
"# sequentially pull each dataset from git lfs, stream it to S3, and then delete the local copy to free up disk memory\n",
|
561 |
+
"for file_name in files:\n",
|
562 |
+
" os.system(f'git lfs pull --include={file_name}')\n",
|
563 |
+
" os.system(f'aws s3 cp {file_name} s3://lodestone-rnd/data/reddit-title-body/{file_name[:-9] + \".json.gz\"}')\n",
|
564 |
+
" os.remove(file_name)\n",
|
565 |
+
" os.system('rm -r .git/lfs/objects/*')\n",
|
566 |
+
" if len(os.listdir('.git/objects/pack')) == 4:\n",
|
567 |
+
" os.system('ls -t .git/objects/pack/* | head -2 | xargs rm --')\n",
|
568 |
+
"print('\\033[32m' + 'Done' + '\\033[0m')\n",
|
569 |
+
"print(f'Total files uploaded: {len(files)}')\n",
|
570 |
+
"\n",
|
571 |
+
"os.chdir('/home/ec2-user/SageMaker')\n",
|
572 |
+
"\n",
|
573 |
+
"print(\"\\n\")\n",
|
574 |
+
"\n",
|
575 |
+
"# DOWNLOAD HUGGINGFACE DATASETS (https://huggingface.co/datasets/sentence-transformers/embedding-training-data)\n",
|
576 |
+
"\n",
|
577 |
+
"# read dataset information from HuggingFace_datasets.tsv\n",
|
578 |
+
"datasets = pd.read_csv(\n",
|
579 |
+
" 'HuggingFace_datasets.tsv',\n",
|
580 |
+
" index_col=0,\n",
|
581 |
+
" sep='\\t',\n",
|
582 |
+
" dtype={\n",
|
583 |
+
" 'Description': str,\n",
|
584 |
+
" 'Size (#Pairs)': str,\n",
|
585 |
+
" 'Performance': float,\n",
|
586 |
+
" 'Download link': str,\n",
|
587 |
+
" 'Source': str})\n",
|
588 |
+
"datasets['Size (#Pairs)'] = datasets['Size (#Pairs)'].str.replace(',', '').astype(int)\n",
|
589 |
+
"datasets = datasets.to_dict(orient='index')\n",
|
590 |
+
"\n",
|
591 |
+
"print('Downloading {:,} HuggingFace datasets into s3://lodestone-rnd/data/'.format(len(datasets)))\n",
|
592 |
+
"\n",
|
593 |
+
"# stream each of the datasets from the URL provided into S3\n",
|
594 |
+
"# (note that S2ORC_citations_abstracts is larger than 50GB and therefore requires the expected size to be passed into the command line as well)\n",
|
595 |
+
"for d in datasets.keys():\n",
|
596 |
+
" print('Downloading dataset {} ({:,} pairs) ... '.format(d, datasets[d]['Size (#Pairs)']), end='', flush=True)\n",
|
597 |
+
" if d == \"S2ORC_citations_abstracts\":\n",
|
598 |
+
" os.system(f'wget -qO- {datasets[d][\"Download link\"]} | aws s3 cp - s3://lodestone-rnd/data/{d + \".json.gz\"} --expected-size 120259084288')\n",
|
599 |
+
" else:\n",
|
600 |
+
" os.system(f'wget -qO- {datasets[d][\"Download link\"]} | aws s3 cp - s3://lodestone-rnd/data/{d + \".json.gz\"}')\n",
|
601 |
+
" print('\\033[32m' + 'Done' + '\\033[0m')\n",
|
602 |
+
"\n",
|
603 |
+
"print(\"\\n\")\n",
|
604 |
+
"\n",
|
605 |
+
"# DOWNLOAD GOOGLE SHEETS DATASETS (https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0)\n",
|
606 |
+
"\n",
|
607 |
+
"# read dataset information from GoogleSheets_datasets.tsv\n",
|
608 |
+
"datasets = pd.read_csv(\n",
|
609 |
+
" 'GoogleSheets_datasets.tsv',\n",
|
610 |
+
" index_col=0,\n",
|
611 |
+
" sep='\\t',\n",
|
612 |
+
" dtype={\n",
|
613 |
+
" 'Description': str,\n",
|
614 |
+
" 'Size (#Pairs)': str,\n",
|
615 |
+
" 'Performance': float,\n",
|
616 |
+
" 'Download link': str,\n",
|
617 |
+
" 'Source': str})\n",
|
618 |
+
"datasets['Size (#Pairs)'] = datasets['Size (#Pairs)'].str.replace(',', '').astype(int)\n",
|
619 |
+
"datasets = datasets.to_dict(orient='index')\n",
|
620 |
+
"\n",
|
621 |
+
"print('Downloading {:,} 1B+ Google Sheets datasets into s3://lodestone-rnd/data/'.format(len(datasets)))\n",
|
622 |
+
"\n",
|
623 |
+
"# stream each of the datasets from the URL provided into S3\n",
|
624 |
+
"for d in datasets.keys():\n",
|
625 |
+
" print('Downloading dataset {} ({:,} pairs) ... '.format(d, datasets[d]['Size (#Pairs)']), end='', flush=True)\n",
|
626 |
+
" os.system(f'wget -qO- {datasets[d][\"Download link\"]} | aws s3 cp - s3://lodestone-rnd/data/{d + \".json.gz\"}')\n",
|
627 |
+
" print('\\033[32m' + 'Done' + '\\033[0m')\n",
|
628 |
+
"\n",
|
629 |
+
"print(\"\\n\")\n",
|
630 |
+
" \n",
|
631 |
+
"print(f'Successfully downloaded 50 datasets and {621+12+9+36} files into s3://lodestone-rnd/data/')"
|
632 |
+
]
|
633 |
+
},
|
634 |
+
{
|
635 |
+
"cell_type": "code",
|
636 |
+
"execution_count": 1,
|
637 |
+
"id": "216d4564-6b69-4688-94ec-a67287134a2d",
|
638 |
+
"metadata": {
|
639 |
+
"tags": []
|
640 |
+
},
|
641 |
+
"outputs": [],
|
642 |
+
"source": [
|
643 |
+
"# clean up (remove the cloned repositories)\n",
|
644 |
+
"import shutil\n",
|
645 |
+
"shutil.rmtree(\"/home/ec2-user/SageMaker/stackexchange_title_body_jsonl\")\n",
|
646 |
+
"shutil.rmtree(\"/home/ec2-user/SageMaker/stackexchange_titlebody_best_voted_answer_jsonl\")\n",
|
647 |
+
"shutil.rmtree(\"/home/ec2-user/SageMaker/stackexchange_title_best_voted_answer_jsonl\")\n",
|
648 |
+
"shutil.rmtree(\"/home/ec2-user/SageMaker/stackexchange_titlebody_best_and_down_voted_answer_jsonl\")\n",
|
649 |
+
"shutil.rmtree(\"/home/ec2-user/SageMaker/reddit-title-body\")\n",
|
650 |
+
"shutil.rmtree(\"/home/ec2-user/SageMaker/1B_sentence_embeddings\")"
|
651 |
+
]
|
652 |
+
}
|
653 |
+
],
|
654 |
+
"metadata": {
|
655 |
+
"kernelspec": {
|
656 |
+
"display_name": "conda_pytorch_p310",
|
657 |
+
"language": "python",
|
658 |
+
"name": "conda_pytorch_p310"
|
659 |
+
},
|
660 |
+
"language_info": {
|
661 |
+
"codemirror_mode": {
|
662 |
+
"name": "ipython",
|
663 |
+
"version": 3
|
664 |
+
},
|
665 |
+
"file_extension": ".py",
|
666 |
+
"mimetype": "text/x-python",
|
667 |
+
"name": "python",
|
668 |
+
"nbconvert_exporter": "python",
|
669 |
+
"pygments_lexer": "ipython3",
|
670 |
+
"version": "3.10.10"
|
671 |
+
}
|
672 |
+
},
|
673 |
+
"nbformat": 4,
|
674 |
+
"nbformat_minor": 5
|
675 |
+
}
|
README.md
CHANGED
@@ -1,3 +1,1940 @@
|
|
1 |
---
|
2 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
+
pipeline_tag: sentence-similarity
|
4 |
+
tags:
|
5 |
+
- sentence-transformers
|
6 |
+
- feature-extraction
|
7 |
+
- sentence-similarity
|
8 |
+
- mteb
|
9 |
+
language: en
|
10 |
+
datasets:
|
11 |
+
- s2orc
|
12 |
+
- flax-sentence-embeddings/stackexchange_title_body_jsonl
|
13 |
+
- flax-sentence-embeddings/stackexchange_titlebody_best_voted_answer_jsonl
|
14 |
+
- flax-sentence-embeddings/stackexchange_title_best_voted_answer_jsonl
|
15 |
+
- flax-sentence-embeddings/stackexchange_titlebody_best_and_down_voted_answer_jsonl
|
16 |
+
- sentence-transformers/reddit-title-body
|
17 |
+
- msmarco
|
18 |
+
- gooaq
|
19 |
+
- yahoo_answers_topics
|
20 |
+
- code_search_net
|
21 |
+
- search_qa
|
22 |
+
- eli5
|
23 |
+
- snli
|
24 |
+
- multi_nli
|
25 |
+
- wikihow
|
26 |
+
- natural_questions
|
27 |
+
- trivia_qa
|
28 |
+
- embedding-data/sentence-compression
|
29 |
+
- embedding-data/flickr30k-captions
|
30 |
+
- embedding-data/altlex
|
31 |
+
- embedding-data/simple-wiki
|
32 |
+
- embedding-data/QQP
|
33 |
+
- embedding-data/SPECTER
|
34 |
+
- embedding-data/PAQ_pairs
|
35 |
+
- embedding-data/WikiAnswers
|
36 |
+
- sentence-transformers/embedding-training-data
|
37 |
+
model-index:
|
38 |
+
- name: lodestone-base-4096-v1
|
39 |
+
results:
|
40 |
+
- task:
|
41 |
+
type: Classification
|
42 |
+
dataset:
|
43 |
+
type: mteb/amazon_counterfactual
|
44 |
+
name: MTEB AmazonCounterfactualClassification (en)
|
45 |
+
config: en
|
46 |
+
split: test
|
47 |
+
revision: e8379541af4e31359cca9fbcf4b00f2671dba205
|
48 |
+
metrics:
|
49 |
+
- type: accuracy
|
50 |
+
value: 69.7313432835821
|
51 |
+
- type: ap
|
52 |
+
value: 31.618259511417733
|
53 |
+
- type: f1
|
54 |
+
value: 63.30313825394228
|
55 |
+
- task:
|
56 |
+
type: Classification
|
57 |
+
dataset:
|
58 |
+
type: mteb/amazon_polarity
|
59 |
+
name: MTEB AmazonPolarityClassification
|
60 |
+
config: default
|
61 |
+
split: test
|
62 |
+
revision: e2d317d38cd51312af73b3d32a06d1a08b442046
|
63 |
+
metrics:
|
64 |
+
- type: accuracy
|
65 |
+
value: 86.89837499999999
|
66 |
+
- type: ap
|
67 |
+
value: 82.39500885672128
|
68 |
+
- type: f1
|
69 |
+
value: 86.87317947399657
|
70 |
+
- task:
|
71 |
+
type: Classification
|
72 |
+
dataset:
|
73 |
+
type: mteb/amazon_reviews_multi
|
74 |
+
name: MTEB AmazonReviewsClassification (en)
|
75 |
+
config: en
|
76 |
+
split: test
|
77 |
+
revision: 1399c76144fd37290681b995c656ef9b2e06e26d
|
78 |
+
metrics:
|
79 |
+
- type: accuracy
|
80 |
+
value: 44.05
|
81 |
+
- type: f1
|
82 |
+
value: 42.67624383248947
|
83 |
+
- task:
|
84 |
+
type: Retrieval
|
85 |
+
dataset:
|
86 |
+
type: arguana
|
87 |
+
name: MTEB ArguAna
|
88 |
+
config: default
|
89 |
+
split: test
|
90 |
+
revision: None
|
91 |
+
metrics:
|
92 |
+
- type: map_at_1
|
93 |
+
value: 26.173999999999996
|
94 |
+
- type: map_at_10
|
95 |
+
value: 40.976
|
96 |
+
- type: map_at_100
|
97 |
+
value: 42.067
|
98 |
+
- type: map_at_1000
|
99 |
+
value: 42.075
|
100 |
+
- type: map_at_3
|
101 |
+
value: 35.917
|
102 |
+
- type: map_at_5
|
103 |
+
value: 38.656
|
104 |
+
- type: mrr_at_1
|
105 |
+
value: 26.814
|
106 |
+
- type: mrr_at_10
|
107 |
+
value: 41.252
|
108 |
+
- type: mrr_at_100
|
109 |
+
value: 42.337
|
110 |
+
- type: mrr_at_1000
|
111 |
+
value: 42.345
|
112 |
+
- type: mrr_at_3
|
113 |
+
value: 36.226
|
114 |
+
- type: mrr_at_5
|
115 |
+
value: 38.914
|
116 |
+
- type: ndcg_at_1
|
117 |
+
value: 26.173999999999996
|
118 |
+
- type: ndcg_at_10
|
119 |
+
value: 49.819
|
120 |
+
- type: ndcg_at_100
|
121 |
+
value: 54.403999999999996
|
122 |
+
- type: ndcg_at_1000
|
123 |
+
value: 54.59
|
124 |
+
- type: ndcg_at_3
|
125 |
+
value: 39.231
|
126 |
+
- type: ndcg_at_5
|
127 |
+
value: 44.189
|
128 |
+
- type: precision_at_1
|
129 |
+
value: 26.173999999999996
|
130 |
+
- type: precision_at_10
|
131 |
+
value: 7.838000000000001
|
132 |
+
- type: precision_at_100
|
133 |
+
value: 0.9820000000000001
|
134 |
+
- type: precision_at_1000
|
135 |
+
value: 0.1
|
136 |
+
- type: precision_at_3
|
137 |
+
value: 16.287
|
138 |
+
- type: precision_at_5
|
139 |
+
value: 12.191
|
140 |
+
- type: recall_at_1
|
141 |
+
value: 26.173999999999996
|
142 |
+
- type: recall_at_10
|
143 |
+
value: 78.378
|
144 |
+
- type: recall_at_100
|
145 |
+
value: 98.222
|
146 |
+
- type: recall_at_1000
|
147 |
+
value: 99.644
|
148 |
+
- type: recall_at_3
|
149 |
+
value: 48.862
|
150 |
+
- type: recall_at_5
|
151 |
+
value: 60.953
|
152 |
+
- task:
|
153 |
+
type: Clustering
|
154 |
+
dataset:
|
155 |
+
type: mteb/arxiv-clustering-p2p
|
156 |
+
name: MTEB ArxivClusteringP2P
|
157 |
+
config: default
|
158 |
+
split: test
|
159 |
+
revision: a122ad7f3f0291bf49cc6f4d32aa80929df69d5d
|
160 |
+
metrics:
|
161 |
+
- type: v_measure
|
162 |
+
value: 42.31689035788179
|
163 |
+
- task:
|
164 |
+
type: Clustering
|
165 |
+
dataset:
|
166 |
+
type: mteb/arxiv-clustering-s2s
|
167 |
+
name: MTEB ArxivClusteringS2S
|
168 |
+
config: default
|
169 |
+
split: test
|
170 |
+
revision: f910caf1a6075f7329cdf8c1a6135696f37dbd53
|
171 |
+
metrics:
|
172 |
+
- type: v_measure
|
173 |
+
value: 31.280245136660984
|
174 |
+
- task:
|
175 |
+
type: Reranking
|
176 |
+
dataset:
|
177 |
+
type: mteb/askubuntudupquestions-reranking
|
178 |
+
name: MTEB AskUbuntuDupQuestions
|
179 |
+
config: default
|
180 |
+
split: test
|
181 |
+
revision: 2000358ca161889fa9c082cb41daa8dcfb161a54
|
182 |
+
metrics:
|
183 |
+
- type: map
|
184 |
+
value: 58.79109720839415
|
185 |
+
- type: mrr
|
186 |
+
value: 71.79615705931495
|
187 |
+
- task:
|
188 |
+
type: STS
|
189 |
+
dataset:
|
190 |
+
type: mteb/biosses-sts
|
191 |
+
name: MTEB BIOSSES
|
192 |
+
config: default
|
193 |
+
split: test
|
194 |
+
revision: d3fb88f8f02e40887cd149695127462bbcf29b4a
|
195 |
+
metrics:
|
196 |
+
- type: cos_sim_pearson
|
197 |
+
value: 76.44918756608115
|
198 |
+
- type: cos_sim_spearman
|
199 |
+
value: 70.86607256286257
|
200 |
+
- type: euclidean_pearson
|
201 |
+
value: 74.12154678100815
|
202 |
+
- type: euclidean_spearman
|
203 |
+
value: 70.86607256286257
|
204 |
+
- type: manhattan_pearson
|
205 |
+
value: 74.0078626964417
|
206 |
+
- type: manhattan_spearman
|
207 |
+
value: 70.68353828321327
|
208 |
+
- task:
|
209 |
+
type: Classification
|
210 |
+
dataset:
|
211 |
+
type: mteb/banking77
|
212 |
+
name: MTEB Banking77Classification
|
213 |
+
config: default
|
214 |
+
split: test
|
215 |
+
revision: 0fd18e25b25c072e09e0d92ab615fda904d66300
|
216 |
+
metrics:
|
217 |
+
- type: accuracy
|
218 |
+
value: 75.40584415584415
|
219 |
+
- type: f1
|
220 |
+
value: 74.29514617572676
|
221 |
+
- task:
|
222 |
+
type: Clustering
|
223 |
+
dataset:
|
224 |
+
type: mteb/biorxiv-clustering-p2p
|
225 |
+
name: MTEB BiorxivClusteringP2P
|
226 |
+
config: default
|
227 |
+
split: test
|
228 |
+
revision: 65b79d1d13f80053f67aca9498d9402c2d9f1f40
|
229 |
+
metrics:
|
230 |
+
- type: v_measure
|
231 |
+
value: 37.41860080664014
|
232 |
+
- task:
|
233 |
+
type: Clustering
|
234 |
+
dataset:
|
235 |
+
type: mteb/biorxiv-clustering-s2s
|
236 |
+
name: MTEB BiorxivClusteringS2S
|
237 |
+
config: default
|
238 |
+
split: test
|
239 |
+
revision: 258694dd0231531bc1fd9de6ceb52a0853c6d908
|
240 |
+
metrics:
|
241 |
+
- type: v_measure
|
242 |
+
value: 29.319217023090705
|
243 |
+
- task:
|
244 |
+
type: Retrieval
|
245 |
+
dataset:
|
246 |
+
type: BeIR/cqadupstack
|
247 |
+
name: MTEB CQADupstackEnglishRetrieval
|
248 |
+
config: default
|
249 |
+
split: test
|
250 |
+
revision: None
|
251 |
+
metrics:
|
252 |
+
- type: map_at_1
|
253 |
+
value: 22.528000000000002
|
254 |
+
- type: map_at_10
|
255 |
+
value: 30.751
|
256 |
+
- type: map_at_100
|
257 |
+
value: 31.855
|
258 |
+
- type: map_at_1000
|
259 |
+
value: 31.972
|
260 |
+
- type: map_at_3
|
261 |
+
value: 28.465
|
262 |
+
- type: map_at_5
|
263 |
+
value: 29.738
|
264 |
+
- type: mrr_at_1
|
265 |
+
value: 28.662
|
266 |
+
- type: mrr_at_10
|
267 |
+
value: 35.912
|
268 |
+
- type: mrr_at_100
|
269 |
+
value: 36.726
|
270 |
+
- type: mrr_at_1000
|
271 |
+
value: 36.777
|
272 |
+
- type: mrr_at_3
|
273 |
+
value: 34.013
|
274 |
+
- type: mrr_at_5
|
275 |
+
value: 35.156
|
276 |
+
- type: ndcg_at_1
|
277 |
+
value: 28.662
|
278 |
+
- type: ndcg_at_10
|
279 |
+
value: 35.452
|
280 |
+
- type: ndcg_at_100
|
281 |
+
value: 40.1
|
282 |
+
- type: ndcg_at_1000
|
283 |
+
value: 42.323
|
284 |
+
- type: ndcg_at_3
|
285 |
+
value: 32.112
|
286 |
+
- type: ndcg_at_5
|
287 |
+
value: 33.638
|
288 |
+
- type: precision_at_1
|
289 |
+
value: 28.662
|
290 |
+
- type: precision_at_10
|
291 |
+
value: 6.688
|
292 |
+
- type: precision_at_100
|
293 |
+
value: 1.13
|
294 |
+
- type: precision_at_1000
|
295 |
+
value: 0.16
|
296 |
+
- type: precision_at_3
|
297 |
+
value: 15.562999999999999
|
298 |
+
- type: precision_at_5
|
299 |
+
value: 11.019
|
300 |
+
- type: recall_at_1
|
301 |
+
value: 22.528000000000002
|
302 |
+
- type: recall_at_10
|
303 |
+
value: 43.748
|
304 |
+
- type: recall_at_100
|
305 |
+
value: 64.235
|
306 |
+
- type: recall_at_1000
|
307 |
+
value: 78.609
|
308 |
+
- type: recall_at_3
|
309 |
+
value: 33.937
|
310 |
+
- type: recall_at_5
|
311 |
+
value: 38.234
|
312 |
+
- task:
|
313 |
+
type: Retrieval
|
314 |
+
dataset:
|
315 |
+
type: climate-fever
|
316 |
+
name: MTEB ClimateFEVER
|
317 |
+
config: default
|
318 |
+
split: test
|
319 |
+
revision: None
|
320 |
+
metrics:
|
321 |
+
- type: map_at_1
|
322 |
+
value: 9.468
|
323 |
+
- type: map_at_10
|
324 |
+
value: 16.029
|
325 |
+
- type: map_at_100
|
326 |
+
value: 17.693
|
327 |
+
- type: map_at_1000
|
328 |
+
value: 17.886
|
329 |
+
- type: map_at_3
|
330 |
+
value: 13.15
|
331 |
+
- type: map_at_5
|
332 |
+
value: 14.568
|
333 |
+
- type: mrr_at_1
|
334 |
+
value: 21.173000000000002
|
335 |
+
- type: mrr_at_10
|
336 |
+
value: 31.028
|
337 |
+
- type: mrr_at_100
|
338 |
+
value: 32.061
|
339 |
+
- type: mrr_at_1000
|
340 |
+
value: 32.119
|
341 |
+
- type: mrr_at_3
|
342 |
+
value: 27.534999999999997
|
343 |
+
- type: mrr_at_5
|
344 |
+
value: 29.431
|
345 |
+
- type: ndcg_at_1
|
346 |
+
value: 21.173000000000002
|
347 |
+
- type: ndcg_at_10
|
348 |
+
value: 23.224
|
349 |
+
- type: ndcg_at_100
|
350 |
+
value: 30.225
|
351 |
+
- type: ndcg_at_1000
|
352 |
+
value: 33.961000000000006
|
353 |
+
- type: ndcg_at_3
|
354 |
+
value: 18.174
|
355 |
+
- type: ndcg_at_5
|
356 |
+
value: 19.897000000000002
|
357 |
+
- type: precision_at_1
|
358 |
+
value: 21.173000000000002
|
359 |
+
- type: precision_at_10
|
360 |
+
value: 7.4719999999999995
|
361 |
+
- type: precision_at_100
|
362 |
+
value: 1.5010000000000001
|
363 |
+
- type: precision_at_1000
|
364 |
+
value: 0.219
|
365 |
+
- type: precision_at_3
|
366 |
+
value: 13.312
|
367 |
+
- type: precision_at_5
|
368 |
+
value: 10.619
|
369 |
+
- type: recall_at_1
|
370 |
+
value: 9.468
|
371 |
+
- type: recall_at_10
|
372 |
+
value: 28.823
|
373 |
+
- type: recall_at_100
|
374 |
+
value: 53.26499999999999
|
375 |
+
- type: recall_at_1000
|
376 |
+
value: 74.536
|
377 |
+
- type: recall_at_3
|
378 |
+
value: 16.672
|
379 |
+
- type: recall_at_5
|
380 |
+
value: 21.302
|
381 |
+
- task:
|
382 |
+
type: Retrieval
|
383 |
+
dataset:
|
384 |
+
type: dbpedia-entity
|
385 |
+
name: MTEB DBPedia
|
386 |
+
config: default
|
387 |
+
split: test
|
388 |
+
revision: None
|
389 |
+
metrics:
|
390 |
+
- type: map_at_1
|
391 |
+
value: 6.343
|
392 |
+
- type: map_at_10
|
393 |
+
value: 12.717
|
394 |
+
- type: map_at_100
|
395 |
+
value: 16.48
|
396 |
+
- type: map_at_1000
|
397 |
+
value: 17.381
|
398 |
+
- type: map_at_3
|
399 |
+
value: 9.568999999999999
|
400 |
+
- type: map_at_5
|
401 |
+
value: 11.125
|
402 |
+
- type: mrr_at_1
|
403 |
+
value: 48.75
|
404 |
+
- type: mrr_at_10
|
405 |
+
value: 58.425000000000004
|
406 |
+
- type: mrr_at_100
|
407 |
+
value: 59.075
|
408 |
+
- type: mrr_at_1000
|
409 |
+
value: 59.095
|
410 |
+
- type: mrr_at_3
|
411 |
+
value: 56.291999999999994
|
412 |
+
- type: mrr_at_5
|
413 |
+
value: 57.679
|
414 |
+
- type: ndcg_at_1
|
415 |
+
value: 37.875
|
416 |
+
- type: ndcg_at_10
|
417 |
+
value: 27.77
|
418 |
+
- type: ndcg_at_100
|
419 |
+
value: 30.288999999999998
|
420 |
+
- type: ndcg_at_1000
|
421 |
+
value: 36.187999999999995
|
422 |
+
- type: ndcg_at_3
|
423 |
+
value: 31.385999999999996
|
424 |
+
- type: ndcg_at_5
|
425 |
+
value: 29.923
|
426 |
+
- type: precision_at_1
|
427 |
+
value: 48.75
|
428 |
+
- type: precision_at_10
|
429 |
+
value: 22.375
|
430 |
+
- type: precision_at_100
|
431 |
+
value: 6.3420000000000005
|
432 |
+
- type: precision_at_1000
|
433 |
+
value: 1.4489999999999998
|
434 |
+
- type: precision_at_3
|
435 |
+
value: 35.5
|
436 |
+
- type: precision_at_5
|
437 |
+
value: 30.55
|
438 |
+
- type: recall_at_1
|
439 |
+
value: 6.343
|
440 |
+
- type: recall_at_10
|
441 |
+
value: 16.936
|
442 |
+
- type: recall_at_100
|
443 |
+
value: 35.955999999999996
|
444 |
+
- type: recall_at_1000
|
445 |
+
value: 55.787
|
446 |
+
- type: recall_at_3
|
447 |
+
value: 10.771
|
448 |
+
- type: recall_at_5
|
449 |
+
value: 13.669999999999998
|
450 |
+
- task:
|
451 |
+
type: Classification
|
452 |
+
dataset:
|
453 |
+
type: mteb/emotion
|
454 |
+
name: MTEB EmotionClassification
|
455 |
+
config: default
|
456 |
+
split: test
|
457 |
+
revision: 4f58c6b202a23cf9a4da393831edf4f9183cad37
|
458 |
+
metrics:
|
459 |
+
- type: accuracy
|
460 |
+
value: 41.99
|
461 |
+
- type: f1
|
462 |
+
value: 36.823402174564954
|
463 |
+
- task:
|
464 |
+
type: Retrieval
|
465 |
+
dataset:
|
466 |
+
type: fever
|
467 |
+
name: MTEB FEVER
|
468 |
+
config: default
|
469 |
+
split: test
|
470 |
+
revision: None
|
471 |
+
metrics:
|
472 |
+
- type: map_at_1
|
473 |
+
value: 40.088
|
474 |
+
- type: map_at_10
|
475 |
+
value: 52.69200000000001
|
476 |
+
- type: map_at_100
|
477 |
+
value: 53.296
|
478 |
+
- type: map_at_1000
|
479 |
+
value: 53.325
|
480 |
+
- type: map_at_3
|
481 |
+
value: 49.905
|
482 |
+
- type: map_at_5
|
483 |
+
value: 51.617000000000004
|
484 |
+
- type: mrr_at_1
|
485 |
+
value: 43.009
|
486 |
+
- type: mrr_at_10
|
487 |
+
value: 56.203
|
488 |
+
- type: mrr_at_100
|
489 |
+
value: 56.75
|
490 |
+
- type: mrr_at_1000
|
491 |
+
value: 56.769000000000005
|
492 |
+
- type: mrr_at_3
|
493 |
+
value: 53.400000000000006
|
494 |
+
- type: mrr_at_5
|
495 |
+
value: 55.163
|
496 |
+
- type: ndcg_at_1
|
497 |
+
value: 43.009
|
498 |
+
- type: ndcg_at_10
|
499 |
+
value: 59.39
|
500 |
+
- type: ndcg_at_100
|
501 |
+
value: 62.129999999999995
|
502 |
+
- type: ndcg_at_1000
|
503 |
+
value: 62.793
|
504 |
+
- type: ndcg_at_3
|
505 |
+
value: 53.878
|
506 |
+
- type: ndcg_at_5
|
507 |
+
value: 56.887
|
508 |
+
- type: precision_at_1
|
509 |
+
value: 43.009
|
510 |
+
- type: precision_at_10
|
511 |
+
value: 8.366
|
512 |
+
- type: precision_at_100
|
513 |
+
value: 0.983
|
514 |
+
- type: precision_at_1000
|
515 |
+
value: 0.105
|
516 |
+
- type: precision_at_3
|
517 |
+
value: 22.377
|
518 |
+
- type: precision_at_5
|
519 |
+
value: 15.035000000000002
|
520 |
+
- type: recall_at_1
|
521 |
+
value: 40.088
|
522 |
+
- type: recall_at_10
|
523 |
+
value: 76.68700000000001
|
524 |
+
- type: recall_at_100
|
525 |
+
value: 88.91
|
526 |
+
- type: recall_at_1000
|
527 |
+
value: 93.782
|
528 |
+
- type: recall_at_3
|
529 |
+
value: 61.809999999999995
|
530 |
+
- type: recall_at_5
|
531 |
+
value: 69.131
|
532 |
+
- task:
|
533 |
+
type: Retrieval
|
534 |
+
dataset:
|
535 |
+
type: fiqa
|
536 |
+
name: MTEB FiQA2018
|
537 |
+
config: default
|
538 |
+
split: test
|
539 |
+
revision: None
|
540 |
+
metrics:
|
541 |
+
- type: map_at_1
|
542 |
+
value: 10.817
|
543 |
+
- type: map_at_10
|
544 |
+
value: 18.9
|
545 |
+
- type: map_at_100
|
546 |
+
value: 20.448
|
547 |
+
- type: map_at_1000
|
548 |
+
value: 20.660999999999998
|
549 |
+
- type: map_at_3
|
550 |
+
value: 15.979
|
551 |
+
- type: map_at_5
|
552 |
+
value: 17.415
|
553 |
+
- type: mrr_at_1
|
554 |
+
value: 23.148
|
555 |
+
- type: mrr_at_10
|
556 |
+
value: 31.208000000000002
|
557 |
+
- type: mrr_at_100
|
558 |
+
value: 32.167
|
559 |
+
- type: mrr_at_1000
|
560 |
+
value: 32.242
|
561 |
+
- type: mrr_at_3
|
562 |
+
value: 28.498
|
563 |
+
- type: mrr_at_5
|
564 |
+
value: 29.964000000000002
|
565 |
+
- type: ndcg_at_1
|
566 |
+
value: 23.148
|
567 |
+
- type: ndcg_at_10
|
568 |
+
value: 25.325999999999997
|
569 |
+
- type: ndcg_at_100
|
570 |
+
value: 31.927
|
571 |
+
- type: ndcg_at_1000
|
572 |
+
value: 36.081
|
573 |
+
- type: ndcg_at_3
|
574 |
+
value: 21.647
|
575 |
+
- type: ndcg_at_5
|
576 |
+
value: 22.762999999999998
|
577 |
+
- type: precision_at_1
|
578 |
+
value: 23.148
|
579 |
+
- type: precision_at_10
|
580 |
+
value: 7.546
|
581 |
+
- type: precision_at_100
|
582 |
+
value: 1.415
|
583 |
+
- type: precision_at_1000
|
584 |
+
value: 0.216
|
585 |
+
- type: precision_at_3
|
586 |
+
value: 14.969
|
587 |
+
- type: precision_at_5
|
588 |
+
value: 11.327
|
589 |
+
- type: recall_at_1
|
590 |
+
value: 10.817
|
591 |
+
- type: recall_at_10
|
592 |
+
value: 32.164
|
593 |
+
- type: recall_at_100
|
594 |
+
value: 57.655
|
595 |
+
- type: recall_at_1000
|
596 |
+
value: 82.797
|
597 |
+
- type: recall_at_3
|
598 |
+
value: 19.709
|
599 |
+
- type: recall_at_5
|
600 |
+
value: 24.333
|
601 |
+
- task:
|
602 |
+
type: Retrieval
|
603 |
+
dataset:
|
604 |
+
type: hotpotqa
|
605 |
+
name: MTEB HotpotQA
|
606 |
+
config: default
|
607 |
+
split: test
|
608 |
+
revision: None
|
609 |
+
metrics:
|
610 |
+
- type: map_at_1
|
611 |
+
value: 25.380999999999997
|
612 |
+
- type: map_at_10
|
613 |
+
value: 33.14
|
614 |
+
- type: map_at_100
|
615 |
+
value: 33.948
|
616 |
+
- type: map_at_1000
|
617 |
+
value: 34.028000000000006
|
618 |
+
- type: map_at_3
|
619 |
+
value: 31.019999999999996
|
620 |
+
- type: map_at_5
|
621 |
+
value: 32.23
|
622 |
+
- type: mrr_at_1
|
623 |
+
value: 50.763000000000005
|
624 |
+
- type: mrr_at_10
|
625 |
+
value: 57.899
|
626 |
+
- type: mrr_at_100
|
627 |
+
value: 58.426
|
628 |
+
- type: mrr_at_1000
|
629 |
+
value: 58.457
|
630 |
+
- type: mrr_at_3
|
631 |
+
value: 56.093
|
632 |
+
- type: mrr_at_5
|
633 |
+
value: 57.116
|
634 |
+
- type: ndcg_at_1
|
635 |
+
value: 50.763000000000005
|
636 |
+
- type: ndcg_at_10
|
637 |
+
value: 41.656
|
638 |
+
- type: ndcg_at_100
|
639 |
+
value: 45.079
|
640 |
+
- type: ndcg_at_1000
|
641 |
+
value: 46.916999999999994
|
642 |
+
- type: ndcg_at_3
|
643 |
+
value: 37.834
|
644 |
+
- type: ndcg_at_5
|
645 |
+
value: 39.732
|
646 |
+
- type: precision_at_1
|
647 |
+
value: 50.763000000000005
|
648 |
+
- type: precision_at_10
|
649 |
+
value: 8.648
|
650 |
+
- type: precision_at_100
|
651 |
+
value: 1.135
|
652 |
+
- type: precision_at_1000
|
653 |
+
value: 0.13799999999999998
|
654 |
+
- type: precision_at_3
|
655 |
+
value: 23.105999999999998
|
656 |
+
- type: precision_at_5
|
657 |
+
value: 15.363
|
658 |
+
- type: recall_at_1
|
659 |
+
value: 25.380999999999997
|
660 |
+
- type: recall_at_10
|
661 |
+
value: 43.241
|
662 |
+
- type: recall_at_100
|
663 |
+
value: 56.745000000000005
|
664 |
+
- type: recall_at_1000
|
665 |
+
value: 69.048
|
666 |
+
- type: recall_at_3
|
667 |
+
value: 34.659
|
668 |
+
- type: recall_at_5
|
669 |
+
value: 38.406
|
670 |
+
- task:
|
671 |
+
type: Classification
|
672 |
+
dataset:
|
673 |
+
type: mteb/imdb
|
674 |
+
name: MTEB ImdbClassification
|
675 |
+
config: default
|
676 |
+
split: test
|
677 |
+
revision: 3d86128a09e091d6018b6d26cad27f2739fc2db7
|
678 |
+
metrics:
|
679 |
+
- type: accuracy
|
680 |
+
value: 79.544
|
681 |
+
- type: ap
|
682 |
+
value: 73.82920133396664
|
683 |
+
- type: f1
|
684 |
+
value: 79.51048124883265
|
685 |
+
- task:
|
686 |
+
type: Retrieval
|
687 |
+
dataset:
|
688 |
+
type: msmarco
|
689 |
+
name: MTEB MSMARCO
|
690 |
+
config: default
|
691 |
+
split: dev
|
692 |
+
revision: None
|
693 |
+
metrics:
|
694 |
+
- type: map_at_1
|
695 |
+
value: 11.174000000000001
|
696 |
+
- type: map_at_10
|
697 |
+
value: 19.451999999999998
|
698 |
+
- type: map_at_100
|
699 |
+
value: 20.612
|
700 |
+
- type: map_at_1000
|
701 |
+
value: 20.703
|
702 |
+
- type: map_at_3
|
703 |
+
value: 16.444
|
704 |
+
- type: map_at_5
|
705 |
+
value: 18.083
|
706 |
+
- type: mrr_at_1
|
707 |
+
value: 11.447000000000001
|
708 |
+
- type: mrr_at_10
|
709 |
+
value: 19.808
|
710 |
+
- type: mrr_at_100
|
711 |
+
value: 20.958
|
712 |
+
- type: mrr_at_1000
|
713 |
+
value: 21.041999999999998
|
714 |
+
- type: mrr_at_3
|
715 |
+
value: 16.791
|
716 |
+
- type: mrr_at_5
|
717 |
+
value: 18.459
|
718 |
+
- type: ndcg_at_1
|
719 |
+
value: 11.447000000000001
|
720 |
+
- type: ndcg_at_10
|
721 |
+
value: 24.556
|
722 |
+
- type: ndcg_at_100
|
723 |
+
value: 30.637999999999998
|
724 |
+
- type: ndcg_at_1000
|
725 |
+
value: 33.14
|
726 |
+
- type: ndcg_at_3
|
727 |
+
value: 18.325
|
728 |
+
- type: ndcg_at_5
|
729 |
+
value: 21.278
|
730 |
+
- type: precision_at_1
|
731 |
+
value: 11.447000000000001
|
732 |
+
- type: precision_at_10
|
733 |
+
value: 4.215
|
734 |
+
- type: precision_at_100
|
735 |
+
value: 0.732
|
736 |
+
- type: precision_at_1000
|
737 |
+
value: 0.095
|
738 |
+
- type: precision_at_3
|
739 |
+
value: 8.052
|
740 |
+
- type: precision_at_5
|
741 |
+
value: 6.318
|
742 |
+
- type: recall_at_1
|
743 |
+
value: 11.174000000000001
|
744 |
+
- type: recall_at_10
|
745 |
+
value: 40.543
|
746 |
+
- type: recall_at_100
|
747 |
+
value: 69.699
|
748 |
+
- type: recall_at_1000
|
749 |
+
value: 89.403
|
750 |
+
- type: recall_at_3
|
751 |
+
value: 23.442
|
752 |
+
- type: recall_at_5
|
753 |
+
value: 30.536
|
754 |
+
- task:
|
755 |
+
type: Classification
|
756 |
+
dataset:
|
757 |
+
type: mteb/mtop_domain
|
758 |
+
name: MTEB MTOPDomainClassification (en)
|
759 |
+
config: en
|
760 |
+
split: test
|
761 |
+
revision: d80d48c1eb48d3562165c59d59d0034df9fff0bf
|
762 |
+
metrics:
|
763 |
+
- type: accuracy
|
764 |
+
value: 89.6671226630187
|
765 |
+
- type: f1
|
766 |
+
value: 89.57660424361246
|
767 |
+
- task:
|
768 |
+
type: Classification
|
769 |
+
dataset:
|
770 |
+
type: mteb/mtop_intent
|
771 |
+
name: MTEB MTOPIntentClassification (en)
|
772 |
+
config: en
|
773 |
+
split: test
|
774 |
+
revision: ae001d0e6b1228650b7bd1c2c65fb50ad11a8aba
|
775 |
+
metrics:
|
776 |
+
- type: accuracy
|
777 |
+
value: 60.284997720018254
|
778 |
+
- type: f1
|
779 |
+
value: 40.30637400152823
|
780 |
+
- task:
|
781 |
+
type: Classification
|
782 |
+
dataset:
|
783 |
+
type: mteb/amazon_massive_intent
|
784 |
+
name: MTEB MassiveIntentClassification (en)
|
785 |
+
config: en
|
786 |
+
split: test
|
787 |
+
revision: 31efe3c427b0bae9c22cbb560b8f15491cc6bed7
|
788 |
+
metrics:
|
789 |
+
- type: accuracy
|
790 |
+
value: 63.33557498318763
|
791 |
+
- type: f1
|
792 |
+
value: 60.24039910680179
|
793 |
+
- task:
|
794 |
+
type: Classification
|
795 |
+
dataset:
|
796 |
+
type: mteb/amazon_massive_scenario
|
797 |
+
name: MTEB MassiveScenarioClassification (en)
|
798 |
+
config: en
|
799 |
+
split: test
|
800 |
+
revision: 7d571f92784cd94a019292a1f45445077d0ef634
|
801 |
+
metrics:
|
802 |
+
- type: accuracy
|
803 |
+
value: 72.37390719569603
|
804 |
+
- type: f1
|
805 |
+
value: 72.33097333477316
|
806 |
+
- task:
|
807 |
+
type: Clustering
|
808 |
+
dataset:
|
809 |
+
type: mteb/medrxiv-clustering-p2p
|
810 |
+
name: MTEB MedrxivClusteringP2P
|
811 |
+
config: default
|
812 |
+
split: test
|
813 |
+
revision: e7a26af6f3ae46b30dde8737f02c07b1505bcc73
|
814 |
+
metrics:
|
815 |
+
- type: v_measure
|
816 |
+
value: 34.68158939060552
|
817 |
+
- task:
|
818 |
+
type: Clustering
|
819 |
+
dataset:
|
820 |
+
type: mteb/medrxiv-clustering-s2s
|
821 |
+
name: MTEB MedrxivClusteringS2S
|
822 |
+
config: default
|
823 |
+
split: test
|
824 |
+
revision: 35191c8c0dca72d8ff3efcd72aa802307d469663
|
825 |
+
metrics:
|
826 |
+
- type: v_measure
|
827 |
+
value: 30.340061711905236
|
828 |
+
- task:
|
829 |
+
type: Reranking
|
830 |
+
dataset:
|
831 |
+
type: mteb/mind_small
|
832 |
+
name: MTEB MindSmallReranking
|
833 |
+
config: default
|
834 |
+
split: test
|
835 |
+
revision: 3bdac13927fdc888b903db93b2ffdbd90b295a69
|
836 |
+
metrics:
|
837 |
+
- type: map
|
838 |
+
value: 32.01814326295803
|
839 |
+
- type: mrr
|
840 |
+
value: 33.20555240055367
|
841 |
+
- task:
|
842 |
+
type: Retrieval
|
843 |
+
dataset:
|
844 |
+
type: nfcorpus
|
845 |
+
name: MTEB NFCorpus
|
846 |
+
config: default
|
847 |
+
split: test
|
848 |
+
revision: None
|
849 |
+
metrics:
|
850 |
+
- type: map_at_1
|
851 |
+
value: 3.3910000000000005
|
852 |
+
- type: map_at_10
|
853 |
+
value: 7.7219999999999995
|
854 |
+
- type: map_at_100
|
855 |
+
value: 10.286
|
856 |
+
- type: map_at_1000
|
857 |
+
value: 11.668000000000001
|
858 |
+
- type: map_at_3
|
859 |
+
value: 5.552
|
860 |
+
- type: map_at_5
|
861 |
+
value: 6.468
|
862 |
+
- type: mrr_at_1
|
863 |
+
value: 34.365
|
864 |
+
- type: mrr_at_10
|
865 |
+
value: 42.555
|
866 |
+
- type: mrr_at_100
|
867 |
+
value: 43.295
|
868 |
+
- type: mrr_at_1000
|
869 |
+
value: 43.357
|
870 |
+
- type: mrr_at_3
|
871 |
+
value: 40.299
|
872 |
+
- type: mrr_at_5
|
873 |
+
value: 41.182
|
874 |
+
- type: ndcg_at_1
|
875 |
+
value: 31.424000000000003
|
876 |
+
- type: ndcg_at_10
|
877 |
+
value: 24.758
|
878 |
+
- type: ndcg_at_100
|
879 |
+
value: 23.677999999999997
|
880 |
+
- type: ndcg_at_1000
|
881 |
+
value: 33.377
|
882 |
+
- type: ndcg_at_3
|
883 |
+
value: 28.302
|
884 |
+
- type: ndcg_at_5
|
885 |
+
value: 26.342
|
886 |
+
- type: precision_at_1
|
887 |
+
value: 33.437
|
888 |
+
- type: precision_at_10
|
889 |
+
value: 19.256999999999998
|
890 |
+
- type: precision_at_100
|
891 |
+
value: 6.662999999999999
|
892 |
+
- type: precision_at_1000
|
893 |
+
value: 1.9900000000000002
|
894 |
+
- type: precision_at_3
|
895 |
+
value: 27.761000000000003
|
896 |
+
- type: precision_at_5
|
897 |
+
value: 23.715
|
898 |
+
- type: recall_at_1
|
899 |
+
value: 3.3910000000000005
|
900 |
+
- type: recall_at_10
|
901 |
+
value: 11.068
|
902 |
+
- type: recall_at_100
|
903 |
+
value: 25.878
|
904 |
+
- type: recall_at_1000
|
905 |
+
value: 60.19
|
906 |
+
- type: recall_at_3
|
907 |
+
value: 6.1690000000000005
|
908 |
+
- type: recall_at_5
|
909 |
+
value: 7.767
|
910 |
+
- task:
|
911 |
+
type: Retrieval
|
912 |
+
dataset:
|
913 |
+
type: nq
|
914 |
+
name: MTEB NQ
|
915 |
+
config: default
|
916 |
+
split: test
|
917 |
+
revision: None
|
918 |
+
metrics:
|
919 |
+
- type: map_at_1
|
920 |
+
value: 15.168000000000001
|
921 |
+
- type: map_at_10
|
922 |
+
value: 26.177
|
923 |
+
- type: map_at_100
|
924 |
+
value: 27.564
|
925 |
+
- type: map_at_1000
|
926 |
+
value: 27.628999999999998
|
927 |
+
- type: map_at_3
|
928 |
+
value: 22.03
|
929 |
+
- type: map_at_5
|
930 |
+
value: 24.276
|
931 |
+
- type: mrr_at_1
|
932 |
+
value: 17.439
|
933 |
+
- type: mrr_at_10
|
934 |
+
value: 28.205000000000002
|
935 |
+
- type: mrr_at_100
|
936 |
+
value: 29.357
|
937 |
+
- type: mrr_at_1000
|
938 |
+
value: 29.408
|
939 |
+
- type: mrr_at_3
|
940 |
+
value: 24.377
|
941 |
+
- type: mrr_at_5
|
942 |
+
value: 26.540000000000003
|
943 |
+
- type: ndcg_at_1
|
944 |
+
value: 17.41
|
945 |
+
- type: ndcg_at_10
|
946 |
+
value: 32.936
|
947 |
+
- type: ndcg_at_100
|
948 |
+
value: 39.196999999999996
|
949 |
+
- type: ndcg_at_1000
|
950 |
+
value: 40.892
|
951 |
+
- type: ndcg_at_3
|
952 |
+
value: 24.721
|
953 |
+
- type: ndcg_at_5
|
954 |
+
value: 28.615000000000002
|
955 |
+
- type: precision_at_1
|
956 |
+
value: 17.41
|
957 |
+
- type: precision_at_10
|
958 |
+
value: 6.199000000000001
|
959 |
+
- type: precision_at_100
|
960 |
+
value: 0.9690000000000001
|
961 |
+
- type: precision_at_1000
|
962 |
+
value: 0.11299999999999999
|
963 |
+
- type: precision_at_3
|
964 |
+
value: 11.790000000000001
|
965 |
+
- type: precision_at_5
|
966 |
+
value: 9.264
|
967 |
+
- type: recall_at_1
|
968 |
+
value: 15.168000000000001
|
969 |
+
- type: recall_at_10
|
970 |
+
value: 51.914
|
971 |
+
- type: recall_at_100
|
972 |
+
value: 79.804
|
973 |
+
- type: recall_at_1000
|
974 |
+
value: 92.75999999999999
|
975 |
+
- type: recall_at_3
|
976 |
+
value: 30.212
|
977 |
+
- type: recall_at_5
|
978 |
+
value: 39.204
|
979 |
+
- task:
|
980 |
+
type: Retrieval
|
981 |
+
dataset:
|
982 |
+
type: quora
|
983 |
+
name: MTEB QuoraRetrieval
|
984 |
+
config: default
|
985 |
+
split: test
|
986 |
+
revision: None
|
987 |
+
metrics:
|
988 |
+
- type: map_at_1
|
989 |
+
value: 67.306
|
990 |
+
- type: map_at_10
|
991 |
+
value: 80.634
|
992 |
+
- type: map_at_100
|
993 |
+
value: 81.349
|
994 |
+
- type: map_at_1000
|
995 |
+
value: 81.37299999999999
|
996 |
+
- type: map_at_3
|
997 |
+
value: 77.691
|
998 |
+
- type: map_at_5
|
999 |
+
value: 79.512
|
1000 |
+
- type: mrr_at_1
|
1001 |
+
value: 77.56
|
1002 |
+
- type: mrr_at_10
|
1003 |
+
value: 84.177
|
1004 |
+
- type: mrr_at_100
|
1005 |
+
value: 84.35000000000001
|
1006 |
+
- type: mrr_at_1000
|
1007 |
+
value: 84.353
|
1008 |
+
- type: mrr_at_3
|
1009 |
+
value: 83.003
|
1010 |
+
- type: mrr_at_5
|
1011 |
+
value: 83.799
|
1012 |
+
- type: ndcg_at_1
|
1013 |
+
value: 77.58
|
1014 |
+
- type: ndcg_at_10
|
1015 |
+
value: 84.782
|
1016 |
+
- type: ndcg_at_100
|
1017 |
+
value: 86.443
|
1018 |
+
- type: ndcg_at_1000
|
1019 |
+
value: 86.654
|
1020 |
+
- type: ndcg_at_3
|
1021 |
+
value: 81.67
|
1022 |
+
- type: ndcg_at_5
|
1023 |
+
value: 83.356
|
1024 |
+
- type: precision_at_1
|
1025 |
+
value: 77.58
|
1026 |
+
- type: precision_at_10
|
1027 |
+
value: 12.875
|
1028 |
+
- type: precision_at_100
|
1029 |
+
value: 1.503
|
1030 |
+
- type: precision_at_1000
|
1031 |
+
value: 0.156
|
1032 |
+
- type: precision_at_3
|
1033 |
+
value: 35.63
|
1034 |
+
- type: precision_at_5
|
1035 |
+
value: 23.483999999999998
|
1036 |
+
- type: recall_at_1
|
1037 |
+
value: 67.306
|
1038 |
+
- type: recall_at_10
|
1039 |
+
value: 92.64
|
1040 |
+
- type: recall_at_100
|
1041 |
+
value: 98.681
|
1042 |
+
- type: recall_at_1000
|
1043 |
+
value: 99.79
|
1044 |
+
- type: recall_at_3
|
1045 |
+
value: 83.682
|
1046 |
+
- type: recall_at_5
|
1047 |
+
value: 88.424
|
1048 |
+
- task:
|
1049 |
+
type: Clustering
|
1050 |
+
dataset:
|
1051 |
+
type: mteb/reddit-clustering
|
1052 |
+
name: MTEB RedditClustering
|
1053 |
+
config: default
|
1054 |
+
split: test
|
1055 |
+
revision: 24640382cdbf8abc73003fb0fa6d111a705499eb
|
1056 |
+
metrics:
|
1057 |
+
- type: v_measure
|
1058 |
+
value: 50.76319866126382
|
1059 |
+
- task:
|
1060 |
+
type: Clustering
|
1061 |
+
dataset:
|
1062 |
+
type: mteb/reddit-clustering-p2p
|
1063 |
+
name: MTEB RedditClusteringP2P
|
1064 |
+
config: default
|
1065 |
+
split: test
|
1066 |
+
revision: 282350215ef01743dc01b456c7f5241fa8937f16
|
1067 |
+
metrics:
|
1068 |
+
- type: v_measure
|
1069 |
+
value: 55.024711941648995
|
1070 |
+
- task:
|
1071 |
+
type: Retrieval
|
1072 |
+
dataset:
|
1073 |
+
type: scidocs
|
1074 |
+
name: MTEB SCIDOCS
|
1075 |
+
config: default
|
1076 |
+
split: test
|
1077 |
+
revision: None
|
1078 |
+
metrics:
|
1079 |
+
- type: map_at_1
|
1080 |
+
value: 3.9379999999999997
|
1081 |
+
- type: map_at_10
|
1082 |
+
value: 8.817
|
1083 |
+
- type: map_at_100
|
1084 |
+
value: 10.546999999999999
|
1085 |
+
- type: map_at_1000
|
1086 |
+
value: 10.852
|
1087 |
+
- type: map_at_3
|
1088 |
+
value: 6.351999999999999
|
1089 |
+
- type: map_at_5
|
1090 |
+
value: 7.453
|
1091 |
+
- type: mrr_at_1
|
1092 |
+
value: 19.400000000000002
|
1093 |
+
- type: mrr_at_10
|
1094 |
+
value: 27.371000000000002
|
1095 |
+
- type: mrr_at_100
|
1096 |
+
value: 28.671999999999997
|
1097 |
+
- type: mrr_at_1000
|
1098 |
+
value: 28.747
|
1099 |
+
- type: mrr_at_3
|
1100 |
+
value: 24.583
|
1101 |
+
- type: mrr_at_5
|
1102 |
+
value: 26.143
|
1103 |
+
- type: ndcg_at_1
|
1104 |
+
value: 19.400000000000002
|
1105 |
+
- type: ndcg_at_10
|
1106 |
+
value: 15.264
|
1107 |
+
- type: ndcg_at_100
|
1108 |
+
value: 22.63
|
1109 |
+
- type: ndcg_at_1000
|
1110 |
+
value: 28.559
|
1111 |
+
- type: ndcg_at_3
|
1112 |
+
value: 14.424999999999999
|
1113 |
+
- type: ndcg_at_5
|
1114 |
+
value: 12.520000000000001
|
1115 |
+
- type: precision_at_1
|
1116 |
+
value: 19.400000000000002
|
1117 |
+
- type: precision_at_10
|
1118 |
+
value: 7.8100000000000005
|
1119 |
+
- type: precision_at_100
|
1120 |
+
value: 1.854
|
1121 |
+
- type: precision_at_1000
|
1122 |
+
value: 0.329
|
1123 |
+
- type: precision_at_3
|
1124 |
+
value: 13.100000000000001
|
1125 |
+
- type: precision_at_5
|
1126 |
+
value: 10.68
|
1127 |
+
- type: recall_at_1
|
1128 |
+
value: 3.9379999999999997
|
1129 |
+
- type: recall_at_10
|
1130 |
+
value: 15.903
|
1131 |
+
- type: recall_at_100
|
1132 |
+
value: 37.645
|
1133 |
+
- type: recall_at_1000
|
1134 |
+
value: 66.86
|
1135 |
+
- type: recall_at_3
|
1136 |
+
value: 7.993
|
1137 |
+
- type: recall_at_5
|
1138 |
+
value: 10.885
|
1139 |
+
- task:
|
1140 |
+
type: STS
|
1141 |
+
dataset:
|
1142 |
+
type: mteb/sickr-sts
|
1143 |
+
name: MTEB SICK-R
|
1144 |
+
config: default
|
1145 |
+
split: test
|
1146 |
+
revision: a6ea5a8cab320b040a23452cc28066d9beae2cee
|
1147 |
+
metrics:
|
1148 |
+
- type: cos_sim_pearson
|
1149 |
+
value: 80.12689060151425
|
1150 |
+
- type: cos_sim_spearman
|
1151 |
+
value: 70.46515535094771
|
1152 |
+
- type: euclidean_pearson
|
1153 |
+
value: 77.17160003557223
|
1154 |
+
- type: euclidean_spearman
|
1155 |
+
value: 70.4651757047438
|
1156 |
+
- type: manhattan_pearson
|
1157 |
+
value: 77.18129609281937
|
1158 |
+
- type: manhattan_spearman
|
1159 |
+
value: 70.46610403752913
|
1160 |
+
- task:
|
1161 |
+
type: STS
|
1162 |
+
dataset:
|
1163 |
+
type: mteb/sts12-sts
|
1164 |
+
name: MTEB STS12
|
1165 |
+
config: default
|
1166 |
+
split: test
|
1167 |
+
revision: a0d554a64d88156834ff5ae9920b964011b16384
|
1168 |
+
metrics:
|
1169 |
+
- type: cos_sim_pearson
|
1170 |
+
value: 70.451157033355
|
1171 |
+
- type: cos_sim_spearman
|
1172 |
+
value: 63.99899601697852
|
1173 |
+
- type: euclidean_pearson
|
1174 |
+
value: 67.46985359967678
|
1175 |
+
- type: euclidean_spearman
|
1176 |
+
value: 64.00001637764805
|
1177 |
+
- type: manhattan_pearson
|
1178 |
+
value: 67.56534741780037
|
1179 |
+
- type: manhattan_spearman
|
1180 |
+
value: 64.06533893575366
|
1181 |
+
- task:
|
1182 |
+
type: STS
|
1183 |
+
dataset:
|
1184 |
+
type: mteb/sts13-sts
|
1185 |
+
name: MTEB STS13
|
1186 |
+
config: default
|
1187 |
+
split: test
|
1188 |
+
revision: 7e90230a92c190f1bf69ae9002b8cea547a64cca
|
1189 |
+
metrics:
|
1190 |
+
- type: cos_sim_pearson
|
1191 |
+
value: 77.65086614464292
|
1192 |
+
- type: cos_sim_spearman
|
1193 |
+
value: 78.20169706921848
|
1194 |
+
- type: euclidean_pearson
|
1195 |
+
value: 77.77758172155283
|
1196 |
+
- type: euclidean_spearman
|
1197 |
+
value: 78.20169706921848
|
1198 |
+
- type: manhattan_pearson
|
1199 |
+
value: 77.75077884860052
|
1200 |
+
- type: manhattan_spearman
|
1201 |
+
value: 78.16875216484164
|
1202 |
+
- task:
|
1203 |
+
type: STS
|
1204 |
+
dataset:
|
1205 |
+
type: mteb/sts14-sts
|
1206 |
+
name: MTEB STS14
|
1207 |
+
config: default
|
1208 |
+
split: test
|
1209 |
+
revision: 6031580fec1f6af667f0bd2da0a551cf4f0b2375
|
1210 |
+
metrics:
|
1211 |
+
- type: cos_sim_pearson
|
1212 |
+
value: 76.26381598259717
|
1213 |
+
- type: cos_sim_spearman
|
1214 |
+
value: 70.78377709313477
|
1215 |
+
- type: euclidean_pearson
|
1216 |
+
value: 74.82646556532096
|
1217 |
+
- type: euclidean_spearman
|
1218 |
+
value: 70.78377658155212
|
1219 |
+
- type: manhattan_pearson
|
1220 |
+
value: 74.81784766108225
|
1221 |
+
- type: manhattan_spearman
|
1222 |
+
value: 70.79351454692176
|
1223 |
+
- task:
|
1224 |
+
type: STS
|
1225 |
+
dataset:
|
1226 |
+
type: mteb/sts15-sts
|
1227 |
+
name: MTEB STS15
|
1228 |
+
config: default
|
1229 |
+
split: test
|
1230 |
+
revision: ae752c7c21bf194d8b67fd573edf7ae58183cbe3
|
1231 |
+
metrics:
|
1232 |
+
- type: cos_sim_pearson
|
1233 |
+
value: 79.00532026789739
|
1234 |
+
- type: cos_sim_spearman
|
1235 |
+
value: 80.02708383244838
|
1236 |
+
- type: euclidean_pearson
|
1237 |
+
value: 79.48345422610525
|
1238 |
+
- type: euclidean_spearman
|
1239 |
+
value: 80.02708383244838
|
1240 |
+
- type: manhattan_pearson
|
1241 |
+
value: 79.44519739854803
|
1242 |
+
- type: manhattan_spearman
|
1243 |
+
value: 79.98344094559687
|
1244 |
+
- task:
|
1245 |
+
type: STS
|
1246 |
+
dataset:
|
1247 |
+
type: mteb/sts16-sts
|
1248 |
+
name: MTEB STS16
|
1249 |
+
config: default
|
1250 |
+
split: test
|
1251 |
+
revision: 4d8694f8f0e0100860b497b999b3dbed754a0513
|
1252 |
+
metrics:
|
1253 |
+
- type: cos_sim_pearson
|
1254 |
+
value: 77.32783048164805
|
1255 |
+
- type: cos_sim_spearman
|
1256 |
+
value: 78.79729961288045
|
1257 |
+
- type: euclidean_pearson
|
1258 |
+
value: 78.72111945793154
|
1259 |
+
- type: euclidean_spearman
|
1260 |
+
value: 78.79729904606872
|
1261 |
+
- type: manhattan_pearson
|
1262 |
+
value: 78.72464311117116
|
1263 |
+
- type: manhattan_spearman
|
1264 |
+
value: 78.822591248334
|
1265 |
+
- task:
|
1266 |
+
type: STS
|
1267 |
+
dataset:
|
1268 |
+
type: mteb/sts17-crosslingual-sts
|
1269 |
+
name: MTEB STS17 (en-en)
|
1270 |
+
config: en-en
|
1271 |
+
split: test
|
1272 |
+
revision: af5e6fb845001ecf41f4c1e033ce921939a2a68d
|
1273 |
+
metrics:
|
1274 |
+
- type: cos_sim_pearson
|
1275 |
+
value: 82.04318630630854
|
1276 |
+
- type: cos_sim_spearman
|
1277 |
+
value: 83.87886389259836
|
1278 |
+
- type: euclidean_pearson
|
1279 |
+
value: 83.40385877895086
|
1280 |
+
- type: euclidean_spearman
|
1281 |
+
value: 83.87886389259836
|
1282 |
+
- type: manhattan_pearson
|
1283 |
+
value: 83.46337128901547
|
1284 |
+
- type: manhattan_spearman
|
1285 |
+
value: 83.9723106941644
|
1286 |
+
- task:
|
1287 |
+
type: STS
|
1288 |
+
dataset:
|
1289 |
+
type: mteb/sts22-crosslingual-sts
|
1290 |
+
name: MTEB STS22 (en)
|
1291 |
+
config: en
|
1292 |
+
split: test
|
1293 |
+
revision: 6d1ba47164174a496b7fa5d3569dae26a6813b80
|
1294 |
+
metrics:
|
1295 |
+
- type: cos_sim_pearson
|
1296 |
+
value: 63.003511169944595
|
1297 |
+
- type: cos_sim_spearman
|
1298 |
+
value: 64.39318805580227
|
1299 |
+
- type: euclidean_pearson
|
1300 |
+
value: 65.4797990735967
|
1301 |
+
- type: euclidean_spearman
|
1302 |
+
value: 64.39318805580227
|
1303 |
+
- type: manhattan_pearson
|
1304 |
+
value: 65.44604544280844
|
1305 |
+
- type: manhattan_spearman
|
1306 |
+
value: 64.38742899984233
|
1307 |
+
- task:
|
1308 |
+
type: STS
|
1309 |
+
dataset:
|
1310 |
+
type: mteb/stsbenchmark-sts
|
1311 |
+
name: MTEB STSBenchmark
|
1312 |
+
config: default
|
1313 |
+
split: test
|
1314 |
+
revision: b0fddb56ed78048fa8b90373c8a3cfc37b684831
|
1315 |
+
metrics:
|
1316 |
+
- type: cos_sim_pearson
|
1317 |
+
value: 76.63101237585029
|
1318 |
+
- type: cos_sim_spearman
|
1319 |
+
value: 75.57446967644269
|
1320 |
+
- type: euclidean_pearson
|
1321 |
+
value: 76.93491768734478
|
1322 |
+
- type: euclidean_spearman
|
1323 |
+
value: 75.57446967644269
|
1324 |
+
- type: manhattan_pearson
|
1325 |
+
value: 76.92187567800636
|
1326 |
+
- type: manhattan_spearman
|
1327 |
+
value: 75.57239337194585
|
1328 |
+
- task:
|
1329 |
+
type: Reranking
|
1330 |
+
dataset:
|
1331 |
+
type: mteb/scidocs-reranking
|
1332 |
+
name: MTEB SciDocsRR
|
1333 |
+
config: default
|
1334 |
+
split: test
|
1335 |
+
revision: d3c5e1fc0b855ab6097bf1cda04dd73947d7caab
|
1336 |
+
metrics:
|
1337 |
+
- type: map
|
1338 |
+
value: 78.5376604868993
|
1339 |
+
- type: mrr
|
1340 |
+
value: 92.94422897364073
|
1341 |
+
- task:
|
1342 |
+
type: Retrieval
|
1343 |
+
dataset:
|
1344 |
+
type: scifact
|
1345 |
+
name: MTEB SciFact
|
1346 |
+
config: default
|
1347 |
+
split: test
|
1348 |
+
revision: None
|
1349 |
+
metrics:
|
1350 |
+
- type: map_at_1
|
1351 |
+
value: 38.872
|
1352 |
+
- type: map_at_10
|
1353 |
+
value: 50.417
|
1354 |
+
- type: map_at_100
|
1355 |
+
value: 51.202000000000005
|
1356 |
+
- type: map_at_1000
|
1357 |
+
value: 51.25999999999999
|
1358 |
+
- type: map_at_3
|
1359 |
+
value: 47.02
|
1360 |
+
- type: map_at_5
|
1361 |
+
value: 49.326
|
1362 |
+
- type: mrr_at_1
|
1363 |
+
value: 41.0
|
1364 |
+
- type: mrr_at_10
|
1365 |
+
value: 51.674
|
1366 |
+
- type: mrr_at_100
|
1367 |
+
value: 52.32599999999999
|
1368 |
+
- type: mrr_at_1000
|
1369 |
+
value: 52.376999999999995
|
1370 |
+
- type: mrr_at_3
|
1371 |
+
value: 48.778
|
1372 |
+
- type: mrr_at_5
|
1373 |
+
value: 50.744
|
1374 |
+
- type: ndcg_at_1
|
1375 |
+
value: 41.0
|
1376 |
+
- type: ndcg_at_10
|
1377 |
+
value: 56.027
|
1378 |
+
- type: ndcg_at_100
|
1379 |
+
value: 59.362
|
1380 |
+
- type: ndcg_at_1000
|
1381 |
+
value: 60.839
|
1382 |
+
- type: ndcg_at_3
|
1383 |
+
value: 50.019999999999996
|
1384 |
+
- type: ndcg_at_5
|
1385 |
+
value: 53.644999999999996
|
1386 |
+
- type: precision_at_1
|
1387 |
+
value: 41.0
|
1388 |
+
- type: precision_at_10
|
1389 |
+
value: 8.1
|
1390 |
+
- type: precision_at_100
|
1391 |
+
value: 0.987
|
1392 |
+
- type: precision_at_1000
|
1393 |
+
value: 0.11100000000000002
|
1394 |
+
- type: precision_at_3
|
1395 |
+
value: 20.444000000000003
|
1396 |
+
- type: precision_at_5
|
1397 |
+
value: 14.466999999999999
|
1398 |
+
- type: recall_at_1
|
1399 |
+
value: 38.872
|
1400 |
+
- type: recall_at_10
|
1401 |
+
value: 71.906
|
1402 |
+
- type: recall_at_100
|
1403 |
+
value: 86.367
|
1404 |
+
- type: recall_at_1000
|
1405 |
+
value: 98.0
|
1406 |
+
- type: recall_at_3
|
1407 |
+
value: 56.206
|
1408 |
+
- type: recall_at_5
|
1409 |
+
value: 65.05
|
1410 |
+
- task:
|
1411 |
+
type: PairClassification
|
1412 |
+
dataset:
|
1413 |
+
type: mteb/sprintduplicatequestions-pairclassification
|
1414 |
+
name: MTEB SprintDuplicateQuestions
|
1415 |
+
config: default
|
1416 |
+
split: test
|
1417 |
+
revision: d66bd1f72af766a5cc4b0ca5e00c162f89e8cc46
|
1418 |
+
metrics:
|
1419 |
+
- type: cos_sim_accuracy
|
1420 |
+
value: 99.7039603960396
|
1421 |
+
- type: cos_sim_ap
|
1422 |
+
value: 90.40809844250262
|
1423 |
+
- type: cos_sim_f1
|
1424 |
+
value: 84.53181583031557
|
1425 |
+
- type: cos_sim_precision
|
1426 |
+
value: 87.56698821007502
|
1427 |
+
- type: cos_sim_recall
|
1428 |
+
value: 81.69999999999999
|
1429 |
+
- type: dot_accuracy
|
1430 |
+
value: 99.7039603960396
|
1431 |
+
- type: dot_ap
|
1432 |
+
value: 90.40809844250262
|
1433 |
+
- type: dot_f1
|
1434 |
+
value: 84.53181583031557
|
1435 |
+
- type: dot_precision
|
1436 |
+
value: 87.56698821007502
|
1437 |
+
- type: dot_recall
|
1438 |
+
value: 81.69999999999999
|
1439 |
+
- type: euclidean_accuracy
|
1440 |
+
value: 99.7039603960396
|
1441 |
+
- type: euclidean_ap
|
1442 |
+
value: 90.4080982863383
|
1443 |
+
- type: euclidean_f1
|
1444 |
+
value: 84.53181583031557
|
1445 |
+
- type: euclidean_precision
|
1446 |
+
value: 87.56698821007502
|
1447 |
+
- type: euclidean_recall
|
1448 |
+
value: 81.69999999999999
|
1449 |
+
- type: manhattan_accuracy
|
1450 |
+
value: 99.7
|
1451 |
+
- type: manhattan_ap
|
1452 |
+
value: 90.39771161966652
|
1453 |
+
- type: manhattan_f1
|
1454 |
+
value: 84.32989690721648
|
1455 |
+
- type: manhattan_precision
|
1456 |
+
value: 87.02127659574468
|
1457 |
+
- type: manhattan_recall
|
1458 |
+
value: 81.8
|
1459 |
+
- type: max_accuracy
|
1460 |
+
value: 99.7039603960396
|
1461 |
+
- type: max_ap
|
1462 |
+
value: 90.40809844250262
|
1463 |
+
- type: max_f1
|
1464 |
+
value: 84.53181583031557
|
1465 |
+
- task:
|
1466 |
+
type: Clustering
|
1467 |
+
dataset:
|
1468 |
+
type: mteb/stackexchange-clustering
|
1469 |
+
name: MTEB StackExchangeClustering
|
1470 |
+
config: default
|
1471 |
+
split: test
|
1472 |
+
revision: 6cbc1f7b2bc0622f2e39d2c77fa502909748c259
|
1473 |
+
metrics:
|
1474 |
+
- type: v_measure
|
1475 |
+
value: 59.663210666678715
|
1476 |
+
- task:
|
1477 |
+
type: Clustering
|
1478 |
+
dataset:
|
1479 |
+
type: mteb/stackexchange-clustering-p2p
|
1480 |
+
name: MTEB StackExchangeClusteringP2P
|
1481 |
+
config: default
|
1482 |
+
split: test
|
1483 |
+
revision: 815ca46b2622cec33ccafc3735d572c266efdb44
|
1484 |
+
metrics:
|
1485 |
+
- type: v_measure
|
1486 |
+
value: 32.107791216468776
|
1487 |
+
- task:
|
1488 |
+
type: Reranking
|
1489 |
+
dataset:
|
1490 |
+
type: mteb/stackoverflowdupquestions-reranking
|
1491 |
+
name: MTEB StackOverflowDupQuestions
|
1492 |
+
config: default
|
1493 |
+
split: test
|
1494 |
+
revision: e185fbe320c72810689fc5848eb6114e1ef5ec69
|
1495 |
+
metrics:
|
1496 |
+
- type: map
|
1497 |
+
value: 46.440691925067604
|
1498 |
+
- type: mrr
|
1499 |
+
value: 47.03390257618199
|
1500 |
+
- task:
|
1501 |
+
type: Summarization
|
1502 |
+
dataset:
|
1503 |
+
type: mteb/summeval
|
1504 |
+
name: MTEB SummEval
|
1505 |
+
config: default
|
1506 |
+
split: test
|
1507 |
+
revision: cda12ad7615edc362dbf25a00fdd61d3b1eaf93c
|
1508 |
+
metrics:
|
1509 |
+
- type: cos_sim_pearson
|
1510 |
+
value: 31.067177519784074
|
1511 |
+
- type: cos_sim_spearman
|
1512 |
+
value: 31.234728424648967
|
1513 |
+
- type: dot_pearson
|
1514 |
+
value: 31.06717083018107
|
1515 |
+
- type: dot_spearman
|
1516 |
+
value: 31.234728424648967
|
1517 |
+
- task:
|
1518 |
+
type: Retrieval
|
1519 |
+
dataset:
|
1520 |
+
type: trec-covid
|
1521 |
+
name: MTEB TRECCOVID
|
1522 |
+
config: default
|
1523 |
+
split: test
|
1524 |
+
revision: None
|
1525 |
+
metrics:
|
1526 |
+
- type: map_at_1
|
1527 |
+
value: 0.136
|
1528 |
+
- type: map_at_10
|
1529 |
+
value: 0.767
|
1530 |
+
- type: map_at_100
|
1531 |
+
value: 3.3689999999999998
|
1532 |
+
- type: map_at_1000
|
1533 |
+
value: 8.613999999999999
|
1534 |
+
- type: map_at_3
|
1535 |
+
value: 0.369
|
1536 |
+
- type: map_at_5
|
1537 |
+
value: 0.514
|
1538 |
+
- type: mrr_at_1
|
1539 |
+
value: 48.0
|
1540 |
+
- type: mrr_at_10
|
1541 |
+
value: 63.908
|
1542 |
+
- type: mrr_at_100
|
1543 |
+
value: 64.615
|
1544 |
+
- type: mrr_at_1000
|
1545 |
+
value: 64.615
|
1546 |
+
- type: mrr_at_3
|
1547 |
+
value: 62.0
|
1548 |
+
- type: mrr_at_5
|
1549 |
+
value: 63.4
|
1550 |
+
- type: ndcg_at_1
|
1551 |
+
value: 44.0
|
1552 |
+
- type: ndcg_at_10
|
1553 |
+
value: 38.579
|
1554 |
+
- type: ndcg_at_100
|
1555 |
+
value: 26.409
|
1556 |
+
- type: ndcg_at_1000
|
1557 |
+
value: 26.858999999999998
|
1558 |
+
- type: ndcg_at_3
|
1559 |
+
value: 47.134
|
1560 |
+
- type: ndcg_at_5
|
1561 |
+
value: 43.287
|
1562 |
+
- type: precision_at_1
|
1563 |
+
value: 48.0
|
1564 |
+
- type: precision_at_10
|
1565 |
+
value: 40.400000000000006
|
1566 |
+
- type: precision_at_100
|
1567 |
+
value: 26.640000000000004
|
1568 |
+
- type: precision_at_1000
|
1569 |
+
value: 12.04
|
1570 |
+
- type: precision_at_3
|
1571 |
+
value: 52.666999999999994
|
1572 |
+
- type: precision_at_5
|
1573 |
+
value: 46.800000000000004
|
1574 |
+
- type: recall_at_1
|
1575 |
+
value: 0.136
|
1576 |
+
- type: recall_at_10
|
1577 |
+
value: 1.0070000000000001
|
1578 |
+
- type: recall_at_100
|
1579 |
+
value: 6.318
|
1580 |
+
- type: recall_at_1000
|
1581 |
+
value: 26.522000000000002
|
1582 |
+
- type: recall_at_3
|
1583 |
+
value: 0.41700000000000004
|
1584 |
+
- type: recall_at_5
|
1585 |
+
value: 0.606
|
1586 |
+
- task:
|
1587 |
+
type: Retrieval
|
1588 |
+
dataset:
|
1589 |
+
type: webis-touche2020
|
1590 |
+
name: MTEB Touche2020
|
1591 |
+
config: default
|
1592 |
+
split: test
|
1593 |
+
revision: None
|
1594 |
+
metrics:
|
1595 |
+
- type: map_at_1
|
1596 |
+
value: 1.9949999999999999
|
1597 |
+
- type: map_at_10
|
1598 |
+
value: 8.304
|
1599 |
+
- type: map_at_100
|
1600 |
+
value: 13.644
|
1601 |
+
- type: map_at_1000
|
1602 |
+
value: 15.43
|
1603 |
+
- type: map_at_3
|
1604 |
+
value: 4.788
|
1605 |
+
- type: map_at_5
|
1606 |
+
value: 6.22
|
1607 |
+
- type: mrr_at_1
|
1608 |
+
value: 22.448999999999998
|
1609 |
+
- type: mrr_at_10
|
1610 |
+
value: 37.658
|
1611 |
+
- type: mrr_at_100
|
1612 |
+
value: 38.491
|
1613 |
+
- type: mrr_at_1000
|
1614 |
+
value: 38.503
|
1615 |
+
- type: mrr_at_3
|
1616 |
+
value: 32.312999999999995
|
1617 |
+
- type: mrr_at_5
|
1618 |
+
value: 35.68
|
1619 |
+
- type: ndcg_at_1
|
1620 |
+
value: 21.429000000000002
|
1621 |
+
- type: ndcg_at_10
|
1622 |
+
value: 18.995
|
1623 |
+
- type: ndcg_at_100
|
1624 |
+
value: 32.029999999999994
|
1625 |
+
- type: ndcg_at_1000
|
1626 |
+
value: 44.852
|
1627 |
+
- type: ndcg_at_3
|
1628 |
+
value: 19.464000000000002
|
1629 |
+
- type: ndcg_at_5
|
1630 |
+
value: 19.172
|
1631 |
+
- type: precision_at_1
|
1632 |
+
value: 22.448999999999998
|
1633 |
+
- type: precision_at_10
|
1634 |
+
value: 17.143
|
1635 |
+
- type: precision_at_100
|
1636 |
+
value: 6.877999999999999
|
1637 |
+
- type: precision_at_1000
|
1638 |
+
value: 1.524
|
1639 |
+
- type: precision_at_3
|
1640 |
+
value: 21.769
|
1641 |
+
- type: precision_at_5
|
1642 |
+
value: 20.0
|
1643 |
+
- type: recall_at_1
|
1644 |
+
value: 1.9949999999999999
|
1645 |
+
- type: recall_at_10
|
1646 |
+
value: 13.395999999999999
|
1647 |
+
- type: recall_at_100
|
1648 |
+
value: 44.348
|
1649 |
+
- type: recall_at_1000
|
1650 |
+
value: 82.622
|
1651 |
+
- type: recall_at_3
|
1652 |
+
value: 5.896
|
1653 |
+
- type: recall_at_5
|
1654 |
+
value: 8.554
|
1655 |
+
- task:
|
1656 |
+
type: Classification
|
1657 |
+
dataset:
|
1658 |
+
type: mteb/toxic_conversations_50k
|
1659 |
+
name: MTEB ToxicConversationsClassification
|
1660 |
+
config: default
|
1661 |
+
split: test
|
1662 |
+
revision: d7c0de2777da35d6aae2200a62c6e0e5af397c4c
|
1663 |
+
metrics:
|
1664 |
+
- type: accuracy
|
1665 |
+
value: 67.9394
|
1666 |
+
- type: ap
|
1667 |
+
value: 12.943337263423334
|
1668 |
+
- type: f1
|
1669 |
+
value: 52.28243093094156
|
1670 |
+
- task:
|
1671 |
+
type: Classification
|
1672 |
+
dataset:
|
1673 |
+
type: mteb/tweet_sentiment_extraction
|
1674 |
+
name: MTEB TweetSentimentExtractionClassification
|
1675 |
+
config: default
|
1676 |
+
split: test
|
1677 |
+
revision: d604517c81ca91fe16a244d1248fc021f9ecee7a
|
1678 |
+
metrics:
|
1679 |
+
- type: accuracy
|
1680 |
+
value: 56.414827391058296
|
1681 |
+
- type: f1
|
1682 |
+
value: 56.666412409573105
|
1683 |
+
- task:
|
1684 |
+
type: Clustering
|
1685 |
+
dataset:
|
1686 |
+
type: mteb/twentynewsgroups-clustering
|
1687 |
+
name: MTEB TwentyNewsgroupsClustering
|
1688 |
+
config: default
|
1689 |
+
split: test
|
1690 |
+
revision: 6125ec4e24fa026cec8a478383ee943acfbd5449
|
1691 |
+
metrics:
|
1692 |
+
- type: v_measure
|
1693 |
+
value: 47.009746255495465
|
1694 |
+
- task:
|
1695 |
+
type: PairClassification
|
1696 |
+
dataset:
|
1697 |
+
type: mteb/twittersemeval2015-pairclassification
|
1698 |
+
name: MTEB TwitterSemEval2015
|
1699 |
+
config: default
|
1700 |
+
split: test
|
1701 |
+
revision: 70970daeab8776df92f5ea462b6173c0b46fd2d1
|
1702 |
+
metrics:
|
1703 |
+
- type: cos_sim_accuracy
|
1704 |
+
value: 84.02574953805807
|
1705 |
+
- type: cos_sim_ap
|
1706 |
+
value: 67.66599910763128
|
1707 |
+
- type: cos_sim_f1
|
1708 |
+
value: 63.491277990844985
|
1709 |
+
- type: cos_sim_precision
|
1710 |
+
value: 59.77172140694154
|
1711 |
+
- type: cos_sim_recall
|
1712 |
+
value: 67.70448548812665
|
1713 |
+
- type: dot_accuracy
|
1714 |
+
value: 84.02574953805807
|
1715 |
+
- type: dot_ap
|
1716 |
+
value: 67.66600090945406
|
1717 |
+
- type: dot_f1
|
1718 |
+
value: 63.491277990844985
|
1719 |
+
- type: dot_precision
|
1720 |
+
value: 59.77172140694154
|
1721 |
+
- type: dot_recall
|
1722 |
+
value: 67.70448548812665
|
1723 |
+
- type: euclidean_accuracy
|
1724 |
+
value: 84.02574953805807
|
1725 |
+
- type: euclidean_ap
|
1726 |
+
value: 67.6659842364448
|
1727 |
+
- type: euclidean_f1
|
1728 |
+
value: 63.491277990844985
|
1729 |
+
- type: euclidean_precision
|
1730 |
+
value: 59.77172140694154
|
1731 |
+
- type: euclidean_recall
|
1732 |
+
value: 67.70448548812665
|
1733 |
+
- type: manhattan_accuracy
|
1734 |
+
value: 84.0317100792752
|
1735 |
+
- type: manhattan_ap
|
1736 |
+
value: 67.66351692448987
|
1737 |
+
- type: manhattan_f1
|
1738 |
+
value: 63.48610948306178
|
1739 |
+
- type: manhattan_precision
|
1740 |
+
value: 57.11875131828729
|
1741 |
+
- type: manhattan_recall
|
1742 |
+
value: 71.45118733509234
|
1743 |
+
- type: max_accuracy
|
1744 |
+
value: 84.0317100792752
|
1745 |
+
- type: max_ap
|
1746 |
+
value: 67.66600090945406
|
1747 |
+
- type: max_f1
|
1748 |
+
value: 63.491277990844985
|
1749 |
+
- task:
|
1750 |
+
type: PairClassification
|
1751 |
+
dataset:
|
1752 |
+
type: mteb/twitterurlcorpus-pairclassification
|
1753 |
+
name: MTEB TwitterURLCorpus
|
1754 |
+
config: default
|
1755 |
+
split: test
|
1756 |
+
revision: 8b6510b0b1fa4e4c4f879467980e9be563ec1cdf
|
1757 |
+
metrics:
|
1758 |
+
- type: cos_sim_accuracy
|
1759 |
+
value: 87.53832421314084
|
1760 |
+
- type: cos_sim_ap
|
1761 |
+
value: 83.11416594316626
|
1762 |
+
- type: cos_sim_f1
|
1763 |
+
value: 75.41118114347518
|
1764 |
+
- type: cos_sim_precision
|
1765 |
+
value: 73.12839059674504
|
1766 |
+
- type: cos_sim_recall
|
1767 |
+
value: 77.8410840776101
|
1768 |
+
- type: dot_accuracy
|
1769 |
+
value: 87.53832421314084
|
1770 |
+
- type: dot_ap
|
1771 |
+
value: 83.11416226342155
|
1772 |
+
- type: dot_f1
|
1773 |
+
value: 75.41118114347518
|
1774 |
+
- type: dot_precision
|
1775 |
+
value: 73.12839059674504
|
1776 |
+
- type: dot_recall
|
1777 |
+
value: 77.8410840776101
|
1778 |
+
- type: euclidean_accuracy
|
1779 |
+
value: 87.53832421314084
|
1780 |
+
- type: euclidean_ap
|
1781 |
+
value: 83.11416284455395
|
1782 |
+
- type: euclidean_f1
|
1783 |
+
value: 75.41118114347518
|
1784 |
+
- type: euclidean_precision
|
1785 |
+
value: 73.12839059674504
|
1786 |
+
- type: euclidean_recall
|
1787 |
+
value: 77.8410840776101
|
1788 |
+
- type: manhattan_accuracy
|
1789 |
+
value: 87.49369348391353
|
1790 |
+
- type: manhattan_ap
|
1791 |
+
value: 83.08066812574694
|
1792 |
+
- type: manhattan_f1
|
1793 |
+
value: 75.36561228603892
|
1794 |
+
- type: manhattan_precision
|
1795 |
+
value: 71.9202518363064
|
1796 |
+
- type: manhattan_recall
|
1797 |
+
value: 79.15768401601478
|
1798 |
+
- type: max_accuracy
|
1799 |
+
value: 87.53832421314084
|
1800 |
+
- type: max_ap
|
1801 |
+
value: 83.11416594316626
|
1802 |
+
- type: max_f1
|
1803 |
+
value: 75.41118114347518
|
1804 |
---
|
1805 |
+
|
1806 |
+
# lodestone-base-4096-v1
|
1807 |
+
|
1808 |
+
This new [sentence-transformers](https://www.SBERT.net) model from [Hum](https://www.hum.works/) maps long sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.
|
1809 |
+
|
1810 |
+
## Abstract
|
1811 |
+
|
1812 |
+
In the hopes of furthering Hum's overarching mission of increasing the accessibility and interconnectivity of human knowledge, this model was developed as part of a project intending to boost the maximum input sequence length of sentence embedding models by leveraging recent architectural advances in the design of transformer models such as the incorporation of FlashAttention, Attention with Linear Biases (ALiBi), and Gated Linear Units (GLU). These modifications and enhancements were implemented by the team at MosaicML who designed and constructed the pre-trained [`mosaic-bert-base-seqlen-2048`](https://huggingface.co/mosaicml/mosaic-bert-base-seqlen-2048) model, and more information regarding the details of their development and testing specifications can be found on the model card.
|
1813 |
+
|
1814 |
+
While the fine-tuning procedure followed during the course of this project loosely mirrors that of the of the original [Flax-sentence-embeddings](https://huggingface.co/flax-sentence-embeddings) team responsible for the creation of many other popular sentence-transformers models (e.g. [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2), [all-distilroberta-v1](https://huggingface.co/sentence-transformers/all-distilroberta-v1), and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)), our methodology includes novel techniques for data loading, batch sampling, and model checkpointing intended to improve training efficiency with regards to memory allocation and data storage.
|
1815 |
+
|
1816 |
+
Through combining these well-established and proven fine-tuning practices with novel advances in transformer architectural elements, our `lodestone-base-4096-v1` model is able to achieve comparable performance metrics on standard text embedding evaluation benchmarks while also supporting a longer and more robust input sequence length of 4096 while retaining a smaller, more manageable size capable of being run on either a GPU or CPU.
|
1817 |
+
|
1818 |
+
## Usage
|
1819 |
+
|
1820 |
+
Using this model becomes relatively easy when you have [sentence-transformers](https://www.SBERT.net) installed.
|
1821 |
+
*At the time of publishing, sentence-transformers does not support remote code which is required for flash-attention used by the model. A fork of the sentence-transformers repository that allows remote code execution is provided for convenience. It can be installed using the following command:*
|
1822 |
+
```
|
1823 |
+
pip install git+https://github.com/Hum-Works/sentence-transformers.git
|
1824 |
+
```
|
1825 |
+
|
1826 |
+
Then you can use the model like this:
|
1827 |
+
```python
|
1828 |
+
from sentence_transformers import SentenceTransformer
|
1829 |
+
|
1830 |
+
model = SentenceTransformer('lodestone-base-4096-v1', trust_remote_code=True, revision='v1.0.0')
|
1831 |
+
sentences = ["This is an example sentence", "Each sentence is converted"]
|
1832 |
+
embeddings = model.encode(sentences)
|
1833 |
+
print(embeddings)
|
1834 |
+
```
|
1835 |
+
*Note: The model will use the openAI/Triton implementation of FlashAttention if installed. This is more performant than the fallback, torch implementation. Some platforms and GPUs may not be supported by Triton - up to date compatibility can be found on [Triton’s github page](https://github.com/openai/triton#compatibility).*
|
1836 |
+
|
1837 |
+
------
|
1838 |
+
|
1839 |
+
## Background
|
1840 |
+
|
1841 |
+
The project aims to train sentence embedding models on very large sentence level datasets using a self-supervised contrastive learning objective. We used the pretrained [`mosaic-bert-base-seqlen-2048`](https://huggingface.co/mosaicml/mosaic-bert-base-seqlen-2048) model and fine-tuned it on a nearly 1.5B sentence pairs dataset. We use a contrastive learning objective: given a sentence from the pair, the model should predict which out of a set of randomly sampled other sentences, was actually paired with it in our dataset.
|
1842 |
+
|
1843 |
+
## Intended uses
|
1844 |
+
|
1845 |
+
Our model is intended to be used as a long sentence and paragraph encoder. Given an input text, it outputs a vector containing the semantic information. The sentence vector may be used for information retrieval, clustering, or sentence similarity tasks.
|
1846 |
+
|
1847 |
+
## Training procedure
|
1848 |
+
|
1849 |
+
### Pre-training
|
1850 |
+
|
1851 |
+
We use the pretrained [`mosaic-bert-base-seqlen-2048`](https://huggingface.co/mosaicml/mosaic-bert-base-seqlen-2048). Please refer to the model card for more detailed information about the pre-training procedure.
|
1852 |
+
|
1853 |
+
### Fine-tuning
|
1854 |
+
|
1855 |
+
We fine-tune the model using a contrastive objective. Formally, we compute the dot product of each possible sentence pairing in the batch. We then apply the cross entropy loss by comparing with true pairs.
|
1856 |
+
|
1857 |
+
#### Hyperparameters
|
1858 |
+
|
1859 |
+
We trained our model on an ml.g5.4xlarge EC2 instance with 1 NVIDIA A10G Tensor Core GPU. We train the model during 1.4 million steps using a batch size of 16. We use a learning rate warm up of 500. The sequence length during training was limited to 2048 tokens. We used the AdamW optimizer with a 2e-5 learning rate and weight decay of 0.01 (i.e. the default parameter values for SentenceTransformer.fit()). The full training script is accessible in this current repository: `Training.py`.
|
1860 |
+
|
1861 |
+
## Model Architecture
|
1862 |
+
By incorporating FlashAttention, [Attention with Linear Biases (ALiBi)](https://arxiv.org/abs/2108.12409), and Gated Linear Units (GLU), this model is able to handle input sequences of 4096, 8x longer than that supported by most comparable sentence embedding models.
|
1863 |
+
The model was trained using a sequence length maximum of 2048, but the final model has a maximum sequence length of 4096. This is accomplished by taking advantage of ALiBi’s positional attention extrapolation which has been shown to allow sequence lengths of 2x the initial trained length.
|
1864 |
+
|
1865 |
+
## Full Model Architecture
|
1866 |
+
```
|
1867 |
+
SentenceTransformer(
|
1868 |
+
(0): Transformer({'max_seq_length': 4096, 'do_lower_case': False}) with Transformer model: BertModel
|
1869 |
+
(1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False})
|
1870 |
+
(2): Normalize()
|
1871 |
+
)
|
1872 |
+
```
|
1873 |
+
|
1874 |
+
#### Training data
|
1875 |
+
|
1876 |
+
We use the concatenation from multiple datasets to fine-tune our model. The total number of sentence pairs is nearly 1.5 billion sentences. We sampled each dataset given a weighted probability proportional to its relative contribution to the entire dataset.
|
1877 |
+
The breakdown of the dataset can be seen below, and the entire dataset can be publicly accessed and uploaded via the `Dataloading.ipynb` located within this repository.
|
1878 |
+
|
1879 |
+
| Dataset | Paper | Number of training tuples |
|
1880 |
+
|--------------------------------------------------------|:----------------------------------------:|:--------------------------:|
|
1881 |
+
| [Reddit comments (2015-2018)](https://github.com/PolyAI-LDN/conversational-datasets/tree/master/reddit) | [paper](https://arxiv.org/abs/1904.06472) | 726,484,430 |
|
1882 |
+
| **[S2ORC](https://github.com/allenai/s2orc) Citation pairs (Abstracts)** | [paper](https://aclanthology.org/2020.acl-main.447/) | 252,102,397 |
|
1883 |
+
| **[Reddit posts](https://huggingface.co/datasets/sentence-transformers/reddit-title-body) (Title, Body) pairs** | - | 127,445,911 |
|
1884 |
+
| **[Amazon reviews (2018)](https://huggingface.co/datasets/sentence-transformers/embedding-training-data) (Title, Review) pairs** | - | 87,877,725 |
|
1885 |
+
| [WikiAnswers](https://github.com/afader/oqa#wikianswers-corpus) Duplicate question pairs | [paper](https://doi.org/10.1145/2623330.2623677) | 77,427,422 |
|
1886 |
+
| [PAQ](https://github.com/facebookresearch/PAQ) (Question, Answer) pairs | [paper](https://arxiv.org/abs/2102.07033) | 64,371,441 |
|
1887 |
+
| [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Titles) | [paper](https://aclanthology.org/2020.acl-main.447/) | 52,603,982 |
|
1888 |
+
| [S2ORC](https://github.com/allenai/s2orc) (Title, Abstract) | [paper](https://aclanthology.org/2020.acl-main.447/) | 41,769,185 |
|
1889 |
+
| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_title_body_jsonl) (Title, Body) pairs | - | 25,368,423 |
|
1890 |
+
| [MS MARCO](https://microsoft.github.io/msmarco/) triplets | [paper](https://doi.org/10.1145/3404835.3462804) | 9,144,553 |
|
1891 |
+
| **[Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_title_best_voted_answer_jsonl) (Title, Most Upvoted Answer) pairs** | - | 4,784,250 |
|
1892 |
+
| **[Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_titlebody_best_voted_answer_jsonl) (Title+Body, Most Upvoted Answer) pairs** | - | 4,551,660 |
|
1893 |
+
| [GOOAQ: Open Question Answering with Diverse Answer Types](https://github.com/allenai/gooaq) | [paper](https://arxiv.org/pdf/2104.08727.pdf) | 3,012,496 |
|
1894 |
+
| **[Amazon QA](https://huggingface.co/datasets/sentence-transformers/embedding-training-data)** | - | 2,507,114 |
|
1895 |
+
| [Code Search](https://huggingface.co/datasets/code_search_net) | - | 1,375,067 |
|
1896 |
+
| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 1,198,260 |
|
1897 |
+
| **[AG News]((Title, Description) pairs of news articles from the AG News dataset)** | - | 1,157,745 |
|
1898 |
+
| [COCO](https://cocodataset.org/#home) Image captions | [paper](https://link.springer.com/chapter/10.1007%2F978-3-319-10602-1_48) | 828,395|
|
1899 |
+
| [SPECTER](https://github.com/allenai/specter) citation triplets | [paper](https://doi.org/10.18653/v1/2020.acl-main.207) | 684,100 |
|
1900 |
+
| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Question, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 681,164 |
|
1901 |
+
| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Question) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 659,896 |
|
1902 |
+
| **[CC News](https://huggingface.co/datasets/sentence-transformers/embedding-training-data) (Title, article) pairs** | - | 614,664 |
|
1903 |
+
| **[NPR](https://huggingface.co/datasets/sentence-transformers/embedding-training-data) (Title, Body) pairs** | - | 594,384 |
|
1904 |
+
| [SearchQA](https://huggingface.co/datasets/search_qa) | [paper](https://arxiv.org/abs/1704.05179) | 582,261 |
|
1905 |
+
| **[MS Marco](https://microsoft.github.io/msmarco/) (Query, Answer Passage) pairs** | [paper](https://doi.org/10.1145/3404835.3462804) | 532,751 |
|
1906 |
+
| [Stack Exchange](https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0) (Title, Body) pairs | - | 364,000 |
|
1907 |
+
| [Eli5](https://huggingface.co/datasets/eli5) | [paper](https://doi.org/10.18653/v1/p19-1346) | 325,475 |
|
1908 |
+
| [Flickr 30k](https://shannon.cs.illinois.edu/DenotationGraph/) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/229/33) | 317,695 |
|
1909 |
+
| **[CNN & DailyMail](https://huggingface.co/datasets/sentence-transformers/embedding-training-data) (highlight sentences, article) pairs** | - | 311,971 |
|
1910 |
+
| [Stack Exchange](https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0) Duplicate questions (titles) | - | 304,524 |
|
1911 |
+
| AllNLI ([SNLI](https://nlp.stanford.edu/projects/snli/) and [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) | [paper SNLI](https://doi.org/10.18653/v1/d15-1075), [paper MultiNLI](https://doi.org/10.18653/v1/n18-1101) | 277,230 |
|
1912 |
+
| [Stack Exchange](https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0) Duplicate questions (bodies) | - | 250,518 |
|
1913 |
+
| [Stack Exchange](https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0) Duplicate questions (titles+bodies) | - | 250,459 |
|
1914 |
+
| **[XSUM](https://huggingface.co/datasets/sentence-transformers/embedding-training-data) (Summary, News Article) pairs** | - | 226,711 |
|
1915 |
+
| **[Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_titlebody_best_and_down_voted_answer_jsonl) (Title+Body, Most Upvoted Answer, Most Downvoted Answer) triplets** | - | 216,454 |
|
1916 |
+
| [Sentence Compression](https://github.com/google-research-datasets/sentence-compression) | [paper](https://www.aclweb.org/anthology/D13-1155/) | 180,000 |
|
1917 |
+
| **[FEVER](https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0) training data** | - | 139,051 |
|
1918 |
+
| [Wikihow](https://github.com/pvl/wikihow_pairs_dataset) | [paper](https://arxiv.org/abs/1810.09305) | 128,542 |
|
1919 |
+
| **[SearchQA](https://huggingface.co/datasets/search_qa) (Question, Top-Snippet)** | [paper](https://arxiv.org/abs/1704.05179) | 117,384 |
|
1920 |
+
| [Altlex](https://github.com/chridey/altlex/) | [paper](https://aclanthology.org/P16-1135.pdf) | 112,696 |
|
1921 |
+
| **[Quora Question Duplicates](https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs)** | - | 103,663 |
|
1922 |
+
| [Quora Question Triplets](https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs) | - | 103,663 |
|
1923 |
+
| [Simple Wikipedia](https://cs.pomona.edu/~dkauchak/simplification/) | [paper](https://www.aclweb.org/anthology/P11-2117/) | 102,225 |
|
1924 |
+
| [Natural Questions (NQ)](https://ai.google.com/research/NaturalQuestions) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/1455) | 100,231 |
|
1925 |
+
| [SQuAD2.0](https://rajpurkar.github.io/SQuAD-explorer/) | [paper](https://aclanthology.org/P18-2124.pdf) | 87,599 |
|
1926 |
+
| [TriviaQA](https://huggingface.co/datasets/trivia_qa) | - | 73,346 |
|
1927 |
+
| **Total** | | **1,492,453,113** |
|
1928 |
+
|
1929 |
+
#### Replication
|
1930 |
+
|
1931 |
+
The entire fine-tuning process for this model can be replicated by following the steps outlined in the `Replication.txt` file within this repository. This document explains how to modify the [sentence-transformers](https://www.SBERT.net) library, configure the pre-trained [`mosaic-bert-base-seqlen-2048`](https://huggingface.co/mosaicml/mosaic-bert-base-seqlen-2048) model, load all of the training data, and execute the training script.
|
1932 |
+
|
1933 |
+
#### Limitations
|
1934 |
+
|
1935 |
+
Due to technical constraints (e.g. limited GPU memory capacity), this model was trained with a smaller batch size of 16, making it so that each step during training was less well-informed than it would have been on a higher performance system. This smaller than ideal hyperparameter value will generally cause the model to be more likely to get stuck in a local minimum and for the parameter configuration to take a longer time to converge to the optimum. In order to counteract this potential risk, we trained the model for a larger number of steps than many of its contemporaries to ensure a greater chance of achieving strong performance, but this is an area which could be improved if further fine-tuning was performed.
|
1936 |
+
|
1937 |
+
It is also worth noting that, while this model is able to handle longer input sequences of up to 4096 word pieces, the training dataset used consists of sentence and paragraph pairs and triplets which do not necessarily reach that maximum sequence length. Since the data was not tailored specifically for this larger input size, further fine-tuning may be required to ensure highly accurate embeddings for longer texts of that magnitude.
|
1938 |
+
|
1939 |
+
Finally, as stated on https://huggingface.co/datasets/sentence-transformers/reddit-title-body, an additional reminder and warning regarding the Reddit posts data is that one should "Be aware that this dataset is not filtered for biases, hate-speech, spam, racial slurs etc. It depicts the content as it is posted on Reddit." Thus, while we believe this has not induced any pathological behaviors in the model's performance due to its relatively low prevalence of records in the whole dataset of nearly 1.5B sentence pairs and the fact that this model was trained to produce semantic embeddings rather than generative text outputs, it is always important to be aware of vulnerabilities to bias.
|
1940 |
+
|
Replication.txt
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Lodestone Replication
|
2 |
+
|
3 |
+
The dataloading, library modification, model preparation, and training process can be replicated in a straightforward manner by simply running a few Jupyter notebooks and Python files.
|
4 |
+
|
5 |
+
Data Wrangling and Loading
|
6 |
+
|
7 |
+
Dataloading.ipynb utilizes the contents of the GoogleSheets_datasets.tsv and HuggingFace_datasets.tsv to fetch data from various URLs provided by the original distilroberta team to their curated datasets in cloud storage. The data is then streamed directly into the data folder of the lodestone-rnd S3 bucket in us-east-1. In addition to the data used by the distilroberta team and provided at https://docs.google.com/spreadsheets/d/1vXJrIg38cEaKjOG5y4I4PQwAQFUmCkohbViJ9zj_Emg/edit#gid=0, data was also collected from https://huggingface.co/datasets/sentence-transformers/embedding-training-data and the following HuggingFace dataset repositories:
|
8 |
+
Stack Exchange
|
9 |
+
https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_title_body_jsonl
|
10 |
+
https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_titlebody_best_voted_answer_jsonl
|
11 |
+
https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_title_best_voted_answer_jsonl
|
12 |
+
https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_titlebody_best_and_down_voted_answer_jsonl
|
13 |
+
Reddit
|
14 |
+
https://huggingface.co/datasets/sentence-transformers/reddit-title-body
|
15 |
+
All of the HuggingFace data is handled remotely or pulled via the script in Dataloading.ipynb, so the only files required for this entire process are Dataloading.ipynb, GoogleSheets_datasets.tsv, and HuggingFace_datasets.tsv. Running this notebook results in 679 objects and 310.5GB of data being loaded into S3.
|
16 |
+
|
17 |
+
Once the data is in S3, run Data_Records.ipynb to generate the data_records.json file which contains a dictionary of {filename: record count} pairs and is used throughout the Training.py script.
|
18 |
+
|
19 |
+
Library and Model Preparation
|
20 |
+
|
21 |
+
In order to run the training process with our specific model, we need to make a few custom modifications to the sentence-transformers library and to the config.json file of the mosaic-bert-base-seqlen-2048 base model.
|
22 |
+
|
23 |
+
To alter the sentence-transformers library, clone the repository from https://github.com/UKPLab/sentence-transformers locally and replace the SentenceTransformer.py and Transformer.py files located within the sentence-transformers/sentence_transformers/ and sentence-transformers/sentence_transformers/models/ directories of the cloned repository, respectively, with those located inside dev/ folder. (This has already been done in this notebook instance, but this will have to be completed if training on another system.)
|
24 |
+
|
25 |
+
Before conducting actual training, we also need to clone the mosaic-bert-base-seqlen-2048 model locally and make a few small changes to its config.json file. Running Mosaic_Model.ipynb will execute this process and get our model ready to begin training. (Again, this has already been done in this notebook instance, but this will have to be completed if training on another system.)
|
26 |
+
|
27 |
+
Training
|
28 |
+
|
29 |
+
To perform the final training run, open a SageMaker Terminal window and execute the following:
|
30 |
+
cd SageMaker
|
31 |
+
screen -S training
|
32 |
+
python Training.py
|
33 |
+
^a d (that is, Ctrl + a, then d)
|
34 |
+
|
35 |
+
To reattach to the screen and observe how training is progressing, run `screen -r training` in the Terminal. Occasionally epochs may stall and require manual intervention to kickstart the process again. Pressing ^c (that is, Ctrl+c) inside the screen should suffice the get things going again, but this action will automatically cause the currently stalled epoch to fail and for the training to proceed to the next epoch or data chunk without updating the existing model parameterization. Epoch successes and failures and the cumulative number of successfully completed steps can be monitored via the train_logs.txt file which is updated automatically throughout the course of training.
|
36 |
+
|
37 |
+
The Training.py file can be reconfigured such that training hyperparameters could be passed in through the command line, but, at present, hyperparameters should be set within the file before running it.
|
38 |
+
|
39 |
+
This concludes the steps required for replication of the Lodestone training process.
|
40 |
+
|
Training.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This training script is a duplicate of the Training.ipynb notebook but can be invoked from the terminal
|
2 |
+
|
3 |
+
import os
|
4 |
+
print(os.getcwd())
|
5 |
+
os.environ["PATH"]="/usr/local/cuda-11.7/bin:"+os.getenv("PATH")
|
6 |
+
|
7 |
+
os.system('pip uninstall -y torch')
|
8 |
+
os.system('pip uninstall -y einops')
|
9 |
+
os.system('pip uninstall -y transformers')
|
10 |
+
os.system('pip uninstall -y sentence_transformers')
|
11 |
+
os.system('pip uninstall -y datasets')
|
12 |
+
os.system('pip uninstall -y sagemaker')
|
13 |
+
os.system('pip uninstall -y smart_open')
|
14 |
+
os.system('pip uninstall -y pynvml')
|
15 |
+
|
16 |
+
os.system('pip install -r lodestone-reqs.txt')
|
17 |
+
|
18 |
+
os.system('pip install -e ./sentence-transformers')
|
19 |
+
|
20 |
+
os.system('pip uninstall -y triton')
|
21 |
+
os.system('pip install --no-deps triton==2.0.0.dev20221202')
|
22 |
+
|
23 |
+
#####
|
24 |
+
|
25 |
+
from pynvml import *
|
26 |
+
import math
|
27 |
+
from sentence_transformers import models, losses
|
28 |
+
from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
|
29 |
+
import logging
|
30 |
+
import os
|
31 |
+
import json
|
32 |
+
import torch
|
33 |
+
import boto3
|
34 |
+
from smart_open import open
|
35 |
+
import random
|
36 |
+
import time
|
37 |
+
import gc
|
38 |
+
|
39 |
+
os.environ["PATH"]="/usr/local/cuda-11.7/bin:"+os.getenv("PATH")
|
40 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
41 |
+
|
42 |
+
#####
|
43 |
+
|
44 |
+
|
45 |
+
def print_gpu_utilization():
|
46 |
+
"This helper function outputs the current GPU memory usage."
|
47 |
+
nvmlInit()
|
48 |
+
handle = nvmlDeviceGetHandleByIndex(0)
|
49 |
+
info = nvmlDeviceGetMemoryInfo(handle)
|
50 |
+
return f"GPU memory occupied: {info.used/1024**3} GB."
|
51 |
+
|
52 |
+
#####
|
53 |
+
|
54 |
+
|
55 |
+
class MultiDatasetDataLoader:
|
56 |
+
"""
|
57 |
+
This custom dataloader class consumes a list of datasets and a batch size and produces batches randomly sampled
|
58 |
+
from the datasets provided where each batch consists of records from a single dataset and datasets are chosen
|
59 |
+
for batches in proportion to their total number of records.
|
60 |
+
"""
|
61 |
+
def __init__(self, datasets, batch_size_pairs, batch_size_triplets=None, dataset_size_temp=-1, allow_swap=True):
|
62 |
+
self.allow_swap = allow_swap
|
63 |
+
self.batch_size_pairs = batch_size_pairs
|
64 |
+
self.batch_size_triplets = batch_size_pairs if batch_size_triplets is None else batch_size_triplets
|
65 |
+
|
66 |
+
# Compute dataset weights
|
67 |
+
self.dataset_lengths = list(map(len, datasets))
|
68 |
+
self.dataset_lengths_sum = sum(self.dataset_lengths)
|
69 |
+
|
70 |
+
weights = []
|
71 |
+
# if dataset_size_temp > 0: # Scale probability with dataset size
|
72 |
+
# for dataset in datasets:
|
73 |
+
# prob = len(dataset) / self.dataset_lengths_sum
|
74 |
+
# weights.append(max(1, int(math.pow(prob, 1 / dataset_size_temp) * 1000)))
|
75 |
+
# else: # Equal weighting of all datasets
|
76 |
+
# weights = [100] * len(datasets)
|
77 |
+
for dataset in datasets:
|
78 |
+
weights.append(len(dataset))
|
79 |
+
|
80 |
+
# logging.info("Dataset lengths and weights: {}".format(list(zip(self.dataset_lengths, weights))))
|
81 |
+
|
82 |
+
self.dataset_idx = []
|
83 |
+
self.dataset_idx_pointer = 0
|
84 |
+
|
85 |
+
for idx, weight in enumerate(weights):
|
86 |
+
self.dataset_idx.extend([idx] * weight)
|
87 |
+
random.shuffle(self.dataset_idx)
|
88 |
+
|
89 |
+
self.datasets = []
|
90 |
+
for dataset in datasets:
|
91 |
+
random.shuffle(dataset)
|
92 |
+
self.datasets.append({
|
93 |
+
'elements': dataset,
|
94 |
+
'pointer': 0,
|
95 |
+
})
|
96 |
+
|
97 |
+
def __iter__(self):
|
98 |
+
for _ in range(int(self.__len__())):
|
99 |
+
# Select dataset
|
100 |
+
if self.dataset_idx_pointer >= len(self.dataset_idx):
|
101 |
+
self.dataset_idx_pointer = 0
|
102 |
+
random.shuffle(self.dataset_idx)
|
103 |
+
|
104 |
+
dataset_idx = self.dataset_idx[self.dataset_idx_pointer]
|
105 |
+
self.dataset_idx_pointer += 1
|
106 |
+
|
107 |
+
# Select batch from this dataset
|
108 |
+
dataset = self.datasets[dataset_idx]
|
109 |
+
batch_size = self.batch_size_pairs if len(dataset['elements'][0].texts) == 2 else self.batch_size_triplets
|
110 |
+
|
111 |
+
batch = []
|
112 |
+
texts_in_batch = set()
|
113 |
+
guid_in_batch = set()
|
114 |
+
while len(batch) < batch_size:
|
115 |
+
example = dataset['elements'][dataset['pointer']]
|
116 |
+
|
117 |
+
valid_example = True
|
118 |
+
# First check if one of the texts in already in the batch
|
119 |
+
for text in example.texts:
|
120 |
+
text_norm = text.strip().lower()
|
121 |
+
if text_norm in texts_in_batch:
|
122 |
+
valid_example = False
|
123 |
+
|
124 |
+
texts_in_batch.add(text_norm)
|
125 |
+
|
126 |
+
# If the example has a label, check if label is in batch
|
127 |
+
if example.guid is not None:
|
128 |
+
valid_example = valid_example and example.guid not in guid_in_batch
|
129 |
+
guid_in_batch.add(example.guid)
|
130 |
+
|
131 |
+
if valid_example:
|
132 |
+
if self.allow_swap and random.random() > 0.5:
|
133 |
+
example.texts[0], example.texts[1] = example.texts[1], example.texts[0]
|
134 |
+
|
135 |
+
batch.append(example)
|
136 |
+
|
137 |
+
dataset['pointer'] += 1
|
138 |
+
if dataset['pointer'] >= len(dataset['elements']):
|
139 |
+
dataset['pointer'] = 0
|
140 |
+
random.shuffle(dataset['elements'])
|
141 |
+
|
142 |
+
yield self.collate_fn(batch) if self.collate_fn is not None else batch
|
143 |
+
|
144 |
+
def __len__(self):
|
145 |
+
return int(self.dataset_lengths_sum / self.batch_size_pairs)
|
146 |
+
|
147 |
+
#####
|
148 |
+
|
149 |
+
|
150 |
+
# These four classes of custom generators parse the raw data from the files in S3 and format it into InputExamples which can be properly interpreted by a SentenceTransformer model.
|
151 |
+
|
152 |
+
class RedditTitleBodyDataset:
|
153 |
+
def __init__(self, source_uri, max_seq_length):
|
154 |
+
self.source_uri = source_uri
|
155 |
+
self.s3_client = boto3.client("s3")
|
156 |
+
self.max_seq_length = max_seq_length
|
157 |
+
|
158 |
+
def __iter__(self):
|
159 |
+
while True:
|
160 |
+
for json_line in open(self.source_uri, transport_params={"client": self.s3_client}):
|
161 |
+
data_line = json.loads(json_line.strip())
|
162 |
+
|
163 |
+
if "title" in data_line and "body" in data_line:
|
164 |
+
data = {'guid': None, 'texts': [" ".join(data_line['title'].split(" ")[:self.max_seq_length]), " ".join(data_line['body'].split(" ")[:self.max_seq_length])]}
|
165 |
+
record = InputExample(guid=data.get('guid', None), texts=data['texts'])
|
166 |
+
|
167 |
+
yield record
|
168 |
+
|
169 |
+
|
170 |
+
class RedditYearDataset:
|
171 |
+
def __init__(self, source_uri, max_seq_length):
|
172 |
+
self.source_uri = source_uri
|
173 |
+
self.s3_client = boto3.client("s3")
|
174 |
+
self.max_seq_length = max_seq_length
|
175 |
+
|
176 |
+
def __iter__(self):
|
177 |
+
while True:
|
178 |
+
for json_line in open(self.source_uri, transport_params={"client": self.s3_client}):
|
179 |
+
data_line = json.loads(json_line.strip())
|
180 |
+
|
181 |
+
if "response" in data_line and "context" in data_line:
|
182 |
+
data = {'guid': None, 'texts': [" ".join(data_line['response'].split(" ")[:self.max_seq_length]), " ".join(data_line['context'].split(" ")[:self.max_seq_length])]}
|
183 |
+
record = InputExample(guid=data.get('guid', None), texts=data['texts'])
|
184 |
+
|
185 |
+
yield record
|
186 |
+
|
187 |
+
|
188 |
+
class HuggingFaceQueryPosDataset:
|
189 |
+
def __init__(self, source_uri, max_seq_length):
|
190 |
+
self.source_uri = source_uri
|
191 |
+
self.s3_client = boto3.client("s3")
|
192 |
+
self.max_seq_length = max_seq_length
|
193 |
+
|
194 |
+
def __iter__(self):
|
195 |
+
while True:
|
196 |
+
for json_line in open(self.source_uri, transport_params={"client": self.s3_client}):
|
197 |
+
data_line = json.loads(json_line.strip())
|
198 |
+
|
199 |
+
if "query" in data_line and "pos" in data_line:
|
200 |
+
for i in range(len(data_line['pos'])):
|
201 |
+
data = {'guid': None, 'texts': [" ".join(data_line['query'].split(" ")[:self.max_seq_length]), " ".join(data_line['pos'][i].split(" ")[:self.max_seq_length])]}
|
202 |
+
record = InputExample(guid=data.get('guid', None), texts=data['texts'])
|
203 |
+
|
204 |
+
yield record
|
205 |
+
|
206 |
+
|
207 |
+
class Dataset:
|
208 |
+
def __init__(self, source_uri, max_seq_length):
|
209 |
+
self.source_uri = source_uri
|
210 |
+
self.s3_client = boto3.client("s3")
|
211 |
+
self.max_seq_length = max_seq_length
|
212 |
+
|
213 |
+
def __iter__(self):
|
214 |
+
while True:
|
215 |
+
for json_line in open(self.source_uri, transport_params={"client": self.s3_client}):
|
216 |
+
data_line = json.loads(json_line.strip())
|
217 |
+
|
218 |
+
if not isinstance(data_line, dict):
|
219 |
+
data = {'guid': None, 'texts': data_line}
|
220 |
+
for text_idx in range(len(data['texts'])):
|
221 |
+
data['texts'][text_idx] = " ".join(data['texts'][text_idx].split(" ")[:self.max_seq_length])
|
222 |
+
record = InputExample(guid=data.get('guid', None), texts=data['texts'])
|
223 |
+
else:
|
224 |
+
for text_idx in range(len(data_line['texts'])):
|
225 |
+
data_line['texts'][text_idx] = " ".join(data_line['texts'][text_idx].split(" ")[:self.max_seq_length])
|
226 |
+
record = InputExample(guid=data_line.get('guid', None), texts=data_line['texts'])
|
227 |
+
|
228 |
+
yield record
|
229 |
+
|
230 |
+
#####
|
231 |
+
|
232 |
+
|
233 |
+
def build_generators(data_records, max_seq_length=512, testing=False):
|
234 |
+
"""
|
235 |
+
This function consumes the data_records dictionary and creates a new dictionary of data generators where each entry is
|
236 |
+
of the form {filename: data generator object}.
|
237 |
+
"""
|
238 |
+
if testing:
|
239 |
+
# filepaths = [file for file in list(data_records.keys()) if file.startswith('S2ORC') or file.startswith('reddit_')]
|
240 |
+
filepaths = [file for file in list(data_records.keys())][:3]
|
241 |
+
else:
|
242 |
+
filepaths = list(data_records.keys())
|
243 |
+
generators = {}
|
244 |
+
for filepath in filepaths:
|
245 |
+
filepath = filepath.strip()
|
246 |
+
source_uri = 's3://lodestone-rnd/data/'+filepath
|
247 |
+
if filepath in ['S2ORC_citations_abstracts.json.gz', 'amazon-qa.json.gz'] or 'reddit' in filepath:
|
248 |
+
if "title" in filepath:
|
249 |
+
generators[f'{filepath.split(".")[0]}'] = iter(RedditTitleBodyDataset(source_uri, max_seq_length))
|
250 |
+
elif "reddit" in filepath:
|
251 |
+
generators[f'{filepath.split(".")[0]}'] = iter(RedditYearDataset(source_uri, max_seq_length))
|
252 |
+
else:
|
253 |
+
generators[f'{filepath.split(".")[0]}'] = iter(HuggingFaceQueryPosDataset(source_uri, max_seq_length))
|
254 |
+
else:
|
255 |
+
generators[f'{filepath.split(".")[0]}'] = iter(Dataset(source_uri, max_seq_length))
|
256 |
+
|
257 |
+
return generators
|
258 |
+
|
259 |
+
#####
|
260 |
+
|
261 |
+
|
262 |
+
def produce_data(data_records, num_chunks, generators, batch_size, failed_on=None, first_iter=False, testing=False, temp=-1):
|
263 |
+
"""
|
264 |
+
This function consumes the data_records dictionary, the number of chunks to break the datasets into, the dictionary of
|
265 |
+
data generators, and a batch size and returns a MultiDatasetDataloader which can be fed into the .fit method of a
|
266 |
+
SentenceTransformer model.
|
267 |
+
"""
|
268 |
+
if testing:
|
269 |
+
# filepaths = [file for file in list(data_records.keys()) if file.startswith('S2ORC') or file.startswith('reddit_')]
|
270 |
+
filepaths = [file for file in list(data_records.keys())][:3]
|
271 |
+
else:
|
272 |
+
filepaths = list(data_records.keys())
|
273 |
+
datasets = []
|
274 |
+
for file_idx, filepath in enumerate(filepaths):
|
275 |
+
filepath = filepath.strip()
|
276 |
+
dataset = []
|
277 |
+
|
278 |
+
if failed_on is not None and failed_on != 1 and first_iter:
|
279 |
+
for k in range((failed_on-1)*max(1, data_records[filepath]//num_chunks)):
|
280 |
+
next(generators[f'{filepath.split(".")[0]}'])
|
281 |
+
for m in range(max(1, data_records[filepath]//num_chunks)):
|
282 |
+
dataset.append(next(generators[f'{filepath.split(".")[0]}']))
|
283 |
+
else:
|
284 |
+
for n in range(max(1, data_records[filepath]//num_chunks)):
|
285 |
+
dataset.append(next(generators[f'{filepath.split(".")[0]}']))
|
286 |
+
|
287 |
+
datasets.append(dataset)
|
288 |
+
logging.info("{}. {}: {}".format(file_idx+1, filepath, len(dataset)))
|
289 |
+
|
290 |
+
dataset_lengths_sum = sum(list(map(len, datasets)))
|
291 |
+
|
292 |
+
batch_size_pairs = batch_size_triplets = batch_size
|
293 |
+
# Special data loader to load from multiple datasets
|
294 |
+
train_dataloader = MultiDatasetDataLoader(datasets=datasets,
|
295 |
+
batch_size_pairs=batch_size_pairs,
|
296 |
+
batch_size_triplets=batch_size_triplets,
|
297 |
+
dataset_size_temp=temp)
|
298 |
+
|
299 |
+
return train_dataloader, dataset_lengths_sum
|
300 |
+
|
301 |
+
#####
|
302 |
+
|
303 |
+
|
304 |
+
def construct_model(model_name, max_seq_length=512):
|
305 |
+
"""
|
306 |
+
This function constructs a SentenceTransformer model from a HuggingFace transformer model name
|
307 |
+
or from a local path to a transformer model repository.
|
308 |
+
"""
|
309 |
+
word_embedding_model = models.Transformer(model_name_or_path=model_name,
|
310 |
+
max_seq_length=max_seq_length,
|
311 |
+
tokenizer_name_or_path='bert-base-uncased',
|
312 |
+
trust_remote_code=True,
|
313 |
+
model_args={'torch_dtype': torch.bfloat16})
|
314 |
+
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
|
315 |
+
norm = models.Normalize()
|
316 |
+
model = SentenceTransformer(modules=[word_embedding_model, pooling_model, norm], device='cuda')
|
317 |
+
model[0].tokenizer.model_max_length = max_seq_length
|
318 |
+
|
319 |
+
return model
|
320 |
+
|
321 |
+
#####
|
322 |
+
|
323 |
+
|
324 |
+
# Just some code to print debug information to stdout
|
325 |
+
logging.basicConfig(format='%(asctime)s - %(message)s',
|
326 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
327 |
+
level=logging.INFO,
|
328 |
+
handlers=[LoggingHandler()])
|
329 |
+
# /print debug information to stdout
|
330 |
+
|
331 |
+
#####
|
332 |
+
|
333 |
+
|
334 |
+
# Set Hyperparameters
|
335 |
+
model_name = 'mosaic-bert-base-seqlen-2048'
|
336 |
+
# model_name = 'hum-lodestone-v1'
|
337 |
+
batch_size = 16
|
338 |
+
batch_size_pairs = batch_size_triplets = batch_size
|
339 |
+
max_seq_length = 2048
|
340 |
+
use_amp = False
|
341 |
+
|
342 |
+
num_cycles = 2
|
343 |
+
num_chunks = 50
|
344 |
+
num_epochs = 2
|
345 |
+
steps_per_epoch = 10000
|
346 |
+
# Total training steps = num_cycles * num_chunks * num_epochs * steps_per_epoch = 2 * 50 * 2 * 10,000 = 2,000,000 steps
|
347 |
+
warmup_steps = 500
|
348 |
+
|
349 |
+
testing = False
|
350 |
+
temp = -1
|
351 |
+
|
352 |
+
#####
|
353 |
+
|
354 |
+
|
355 |
+
output_path = 'hum-lodestone-v1'
|
356 |
+
logging.info("Output: "+output_path)
|
357 |
+
|
358 |
+
# Instantiate SentenceTransformer Model
|
359 |
+
model = construct_model(model_name=model_name, max_seq_length=max_seq_length)
|
360 |
+
|
361 |
+
# Load File Names and Record Volumes
|
362 |
+
with open('data_records.json') as fIn:
|
363 |
+
data_records = json.load(fIn)
|
364 |
+
|
365 |
+
total_pairs = sum(data_records.values())
|
366 |
+
|
367 |
+
logging.info("Total Training Pairs: {}".format(total_pairs))
|
368 |
+
|
369 |
+
# Initialize Data Generators
|
370 |
+
generators = build_generators(data_records=data_records,
|
371 |
+
max_seq_length=max_seq_length,
|
372 |
+
testing=testing)
|
373 |
+
|
374 |
+
logging.info("Data Generators Initialized")
|
375 |
+
|
376 |
+
# Define Training Loss Function
|
377 |
+
train_loss = losses.MultipleNegativesRankingLoss(model,
|
378 |
+
scale=20,
|
379 |
+
similarity_fct=util.dot_score)
|
380 |
+
|
381 |
+
logging.info(print_gpu_utilization())
|
382 |
+
|
383 |
+
#####
|
384 |
+
|
385 |
+
|
386 |
+
# Configure Training Cycles
|
387 |
+
failed_on = None # chunk that the process failed on
|
388 |
+
random.seed(42)
|
389 |
+
steps = 0
|
390 |
+
first_iter = True
|
391 |
+
for cycle_num in range(num_cycles):
|
392 |
+
logging.info("Starting Cycle {}".format(cycle_num+1))
|
393 |
+
for chunk_num in range(num_chunks):
|
394 |
+
if failed_on is not None and (chunk_num+1) < failed_on and (cycle_num+1) == 1:
|
395 |
+
pass
|
396 |
+
else:
|
397 |
+
logging.info("Chunk {}/{}".format(chunk_num+1, num_chunks))
|
398 |
+
logging.info("Loading {} Datasets".format(len([file for file in list(data_records.keys()) if file.startswith('S2ORC') or file.startswith('reddit_')]) if testing else len(data_records)))
|
399 |
+
# t_dataload0 = time.time()
|
400 |
+
# Create the training dataloader for the given chunk of data
|
401 |
+
train_dataloader, dataset_lengths_sum = produce_data(data_records,
|
402 |
+
num_chunks,
|
403 |
+
generators,
|
404 |
+
batch_size,
|
405 |
+
failed_on=failed_on,
|
406 |
+
first_iter=first_iter,
|
407 |
+
testing=testing,
|
408 |
+
temp=temp)
|
409 |
+
first_iter = False
|
410 |
+
# t_dataload1 = time.time()
|
411 |
+
# print(t_dataload1-t_dataload0)
|
412 |
+
|
413 |
+
logging.info(print_gpu_utilization())
|
414 |
+
|
415 |
+
# steps_per_epoch = dataset_lengths_sum // batch_size_pairs
|
416 |
+
|
417 |
+
for epoch_num in range(num_epochs):
|
418 |
+
logging.info("Performing Cycle {}, Chunk {}, Epoch {}".format(cycle_num+1, chunk_num+1, epoch_num+1))
|
419 |
+
try:
|
420 |
+
# t_fit0 = time.time()
|
421 |
+
# Train the model
|
422 |
+
model.fit(train_objectives=[(train_dataloader, train_loss)],
|
423 |
+
evaluator=None,
|
424 |
+
epochs=1,
|
425 |
+
warmup_steps=warmup_steps,
|
426 |
+
steps_per_epoch=steps_per_epoch,
|
427 |
+
use_amp=use_amp,
|
428 |
+
output_path=output_path)
|
429 |
+
# t_fit1 = time.time()
|
430 |
+
# print(t_fit1-t_fit0)
|
431 |
+
|
432 |
+
steps += steps_per_epoch
|
433 |
+
|
434 |
+
logging.info(print_gpu_utilization())
|
435 |
+
logging.info("Succeeded on Cycle {}, Chunk {}, Epoch {}".format(cycle_num+1, chunk_num+1, epoch_num+1))
|
436 |
+
logging.info("{} Steps Completed in Total".format(steps))
|
437 |
+
|
438 |
+
with open('train_logs.txt', 'a') as log:
|
439 |
+
log.write("Succeeded on Cycle {}, Chunk {}, Epoch {}: {} Steps Completed in Total\n".format(cycle_num+1, chunk_num+1, epoch_num+1, steps))
|
440 |
+
|
441 |
+
except:
|
442 |
+
logging.info("Failed on Cycle {}, Chunk {}, Epoch {}".format(cycle_num+1, chunk_num+1, epoch_num+1))
|
443 |
+
|
444 |
+
with open('train_logs.txt', 'a') as log:
|
445 |
+
log.write("Failed on Cycle {}, Chunk {}, Epoch {}: {} Steps Completed in Total\n".format(cycle_num+1, chunk_num+1, epoch_num+1, steps))
|
446 |
+
|
447 |
+
finally:
|
448 |
+
warmup_steps = 0
|
449 |
+
|
450 |
+
# Clear GPU/CUDA memory cache between data chunks
|
451 |
+
train_dataloader = None
|
452 |
+
model = None
|
453 |
+
train_loss = None
|
454 |
+
|
455 |
+
gc.collect()
|
456 |
+
torch.cuda.empty_cache()
|
457 |
+
|
458 |
+
# Reload the model and reinitialize the loss function
|
459 |
+
model = construct_model(model_name='hum-lodestone-v1', max_seq_length=max_seq_length)
|
460 |
+
|
461 |
+
train_loss = losses.MultipleNegativesRankingLoss(model,
|
462 |
+
scale=20,
|
463 |
+
similarity_fct=util.dot_score)
|
464 |
+
|
465 |
+
logging.info(print_gpu_utilization())
|
bert_layers.py
ADDED
@@ -0,0 +1,1072 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML Examples authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
5 |
+
# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
|
6 |
+
# Copyright (c) 2022, Tri Dao.
|
7 |
+
|
8 |
+
"""Implements Mosaic BERT, with an eye towards the Hugging Face API.
|
9 |
+
|
10 |
+
Mosaic BERT improves performance over Hugging Face BERT through the following:
|
11 |
+
|
12 |
+
1. ALiBi. This architectural change removes positional embeddings and instead encodes positional
|
13 |
+
information through attention biases based on query-key position distance. It improves the effectiveness
|
14 |
+
of training with shorter sequence lengths by enabling extrapolation to longer sequences.
|
15 |
+
|
16 |
+
2. Gated Linear Units (GLU). This architectural change replaces the FFN component of the BERT layer
|
17 |
+
to improve overall expressiveness, providing better convergence properties.
|
18 |
+
|
19 |
+
3. Flash Attention. The Mosaic BERT's self-attention layer makes use of Flash Attention, which dramatically
|
20 |
+
improves the speed of self-attention. Our implementation utilizes a bleeding edge implementation that
|
21 |
+
supports attention biases, which allows us to use Flash Attention with ALiBi.
|
22 |
+
|
23 |
+
4. Unpadding. Padding is often used to simplify batching across sequences of different lengths. Standard BERT
|
24 |
+
implementations waste computation on padded tokens. Mosaic BERT internally unpads to reduce unnecessary computation
|
25 |
+
and improve speed. It does this without changing how the user interfaces with the model, thereby
|
26 |
+
preserving the simple API of standard implementations.
|
27 |
+
|
28 |
+
|
29 |
+
Currently, Mosaic BERT is available for masked language modeling :class:`BertForMaskedLM` and sequence
|
30 |
+
classification :class:`BertForSequenceClassification`. We aim to expand this catalogue in future releases.
|
31 |
+
|
32 |
+
See :file:`./mosaic_bert.py` for utilities to simplify working with Mosaic BERT in Composer, and for example usage
|
33 |
+
of the core Mosaic BERT classes.
|
34 |
+
"""
|
35 |
+
|
36 |
+
import copy
|
37 |
+
import logging
|
38 |
+
import math
|
39 |
+
import warnings
|
40 |
+
from typing import List, Optional, Tuple, Union
|
41 |
+
|
42 |
+
import torch
|
43 |
+
import torch.nn as nn
|
44 |
+
from einops import rearrange
|
45 |
+
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
46 |
+
from transformers.activations import ACT2FN
|
47 |
+
from transformers.modeling_outputs import (MaskedLMOutput,
|
48 |
+
SequenceClassifierOutput)
|
49 |
+
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
50 |
+
|
51 |
+
from .bert_padding import (index_first_axis,
|
52 |
+
index_put_first_axis, pad_input,
|
53 |
+
unpad_input, unpad_input_only)
|
54 |
+
|
55 |
+
try:
|
56 |
+
from .flash_attn_triton import flash_attn_qkvpacked_func
|
57 |
+
except ImportError as e:
|
58 |
+
flash_attn_qkvpacked_func = None
|
59 |
+
|
60 |
+
logger = logging.getLogger(__name__)
|
61 |
+
|
62 |
+
|
63 |
+
class BertEmbeddings(nn.Module):
|
64 |
+
"""Construct the embeddings for words, ignoring position.
|
65 |
+
|
66 |
+
There are no positional embeddings since we use ALiBi and token_type
|
67 |
+
embeddings.
|
68 |
+
|
69 |
+
This module is modeled after the Hugging Face BERT's
|
70 |
+
:class:`~transformers.model.bert.modeling_bert.BertEmbeddings`, but is
|
71 |
+
modified as part of Mosaic BERT's ALiBi implementation. The key change is
|
72 |
+
that position embeddings are removed. Position information instead comes
|
73 |
+
from attention biases that scale linearly with the position distance
|
74 |
+
between query and key tokens.
|
75 |
+
|
76 |
+
This module ignores the `position_ids` input to the `forward` method.
|
77 |
+
"""
|
78 |
+
|
79 |
+
def __init__(self, config):
|
80 |
+
super().__init__()
|
81 |
+
self.word_embeddings = nn.Embedding(config.vocab_size,
|
82 |
+
config.hidden_size,
|
83 |
+
padding_idx=config.pad_token_id)
|
84 |
+
# ALiBi doesn't use position embeddings
|
85 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
|
86 |
+
config.hidden_size)
|
87 |
+
|
88 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model
|
89 |
+
# variable name and be able to load any TensorFlow checkpoint file
|
90 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size,
|
91 |
+
eps=config.layer_norm_eps)
|
92 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
93 |
+
self.register_buffer('token_type_ids',
|
94 |
+
torch.zeros(config.max_position_embeddings,
|
95 |
+
dtype=torch.long),
|
96 |
+
persistent=False)
|
97 |
+
|
98 |
+
def forward(
|
99 |
+
self,
|
100 |
+
input_ids: Optional[torch.LongTensor] = None,
|
101 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
102 |
+
position_ids: Optional[torch.LongTensor] = None,
|
103 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
104 |
+
past_key_values_length: int = 0,
|
105 |
+
) -> torch.Tensor:
|
106 |
+
if (input_ids is not None) == (inputs_embeds is not None):
|
107 |
+
raise ValueError('Must specify either input_ids or input_embeds!')
|
108 |
+
if input_ids is not None:
|
109 |
+
input_shape = input_ids.size()
|
110 |
+
else:
|
111 |
+
assert inputs_embeds is not None # just for type checking
|
112 |
+
input_shape = inputs_embeds.size()[:-1]
|
113 |
+
|
114 |
+
seq_length = input_shape[1]
|
115 |
+
|
116 |
+
if position_ids is None:
|
117 |
+
# great! ALiBi
|
118 |
+
pass
|
119 |
+
|
120 |
+
# Setting the token_type_ids to the registered buffer in constructor
|
121 |
+
# where it is all zeros, which usually occurs when it's auto-generated;
|
122 |
+
# registered buffer helps users when tracing the model without passing
|
123 |
+
# token_type_ids, solves issue #5664
|
124 |
+
if token_type_ids is None:
|
125 |
+
if hasattr(self, 'token_type_ids'):
|
126 |
+
assert isinstance(self.token_type_ids, torch.LongTensor)
|
127 |
+
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
|
128 |
+
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
|
129 |
+
input_shape[0], seq_length)
|
130 |
+
token_type_ids = buffered_token_type_ids_expanded # type: ignore
|
131 |
+
else:
|
132 |
+
token_type_ids = torch.zeros(input_shape, # type: ignore
|
133 |
+
dtype=torch.long,
|
134 |
+
device=self.word_embeddings.device) # type: ignore # yapf: disable
|
135 |
+
|
136 |
+
if inputs_embeds is None:
|
137 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
138 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
139 |
+
|
140 |
+
embeddings = inputs_embeds + token_type_embeddings
|
141 |
+
# no position embeddings! ALiBi
|
142 |
+
embeddings = self.LayerNorm(embeddings)
|
143 |
+
embeddings = self.dropout(embeddings)
|
144 |
+
return embeddings
|
145 |
+
|
146 |
+
|
147 |
+
class BertUnpadSelfAttention(nn.Module):
|
148 |
+
"""Performs multi-headed self attention on a batch of unpadded sequences.
|
149 |
+
|
150 |
+
If Triton is installed, this module uses Flash Attention to greatly improve throughput.
|
151 |
+
The Flash Attention implementation used in Mosaic BERT supports arbitrary attention biases (which
|
152 |
+
we use to implement ALiBi), but does not support attention dropout. If either Triton is not installed
|
153 |
+
or `config.attention_probs_dropout_prob > 0`, the implementation will default to a
|
154 |
+
math-equivalent pytorch version, which is much slower.
|
155 |
+
|
156 |
+
See `forward` method for additional detail.
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(self, config):
|
160 |
+
super().__init__()
|
161 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
162 |
+
config, 'embedding_size'):
|
163 |
+
raise ValueError(
|
164 |
+
f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention '
|
165 |
+
f'heads ({config.num_attention_heads})')
|
166 |
+
|
167 |
+
self.num_attention_heads = config.num_attention_heads
|
168 |
+
self.attention_head_size = int(config.hidden_size /
|
169 |
+
config.num_attention_heads)
|
170 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
171 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
172 |
+
self.p_dropout = config.attention_probs_dropout_prob
|
173 |
+
self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
|
174 |
+
|
175 |
+
# Warn if defaulting to pytorch because of import issues
|
176 |
+
if flash_attn_qkvpacked_func is None:
|
177 |
+
warnings.warn(
|
178 |
+
'Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model).'
|
179 |
+
)
|
180 |
+
|
181 |
+
def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
|
182 |
+
max_seqlen_in_batch: int, indices: torch.Tensor,
|
183 |
+
attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
|
184 |
+
"""Perform self-attention.
|
185 |
+
|
186 |
+
If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
|
187 |
+
implementation of self-attention.
|
188 |
+
|
189 |
+
The arguments are unpadded, and our implementations of attention require padded arguments,
|
190 |
+
so we first call `pad_input`. Once we compute attention, we re-unpad our outputs for the other layers.
|
191 |
+
The pad/unpad operations add overhead, but not sending pad tokens through ffs saves compute.
|
192 |
+
It is possible to write an unpadded implementation of attention (in Triton and PyTorch), which we will eventually do.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
hidden_states: (total_nnz, dim)
|
196 |
+
cu_seqlens: (batch + 1,)
|
197 |
+
max_seqlen_in_batch: int
|
198 |
+
indices: (total_nnz,)
|
199 |
+
attn_mask: (batch, max_seqlen_in_batch)
|
200 |
+
bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
attention: (total_nnz, dim)
|
204 |
+
"""
|
205 |
+
qkv = self.Wqkv(hidden_states)
|
206 |
+
qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1,
|
207 |
+
max_seqlen_in_batch) # batch, max_seqlen_in_batch, thd
|
208 |
+
qkv = rearrange(qkv,
|
209 |
+
'b s (t h d) -> b s t h d',
|
210 |
+
t=3,
|
211 |
+
h=self.num_attention_heads)
|
212 |
+
if self.p_dropout or flash_attn_qkvpacked_func is None:
|
213 |
+
# if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
|
214 |
+
q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
|
215 |
+
k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
|
216 |
+
v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3) # b h s d
|
217 |
+
attention_scores = torch.matmul(q, k) / math.sqrt(
|
218 |
+
self.attention_head_size)
|
219 |
+
attention_scores = attention_scores + bias
|
220 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
221 |
+
attention_probs = self.dropout(attention_probs)
|
222 |
+
attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
|
223 |
+
3) # b s h d
|
224 |
+
else:
|
225 |
+
# Triton implementation only supports 0 attention dropout
|
226 |
+
convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
|
227 |
+
if convert_dtype:
|
228 |
+
# Triton implementation only supports fp16 and bf16
|
229 |
+
orig_dtype = qkv.dtype
|
230 |
+
qkv = qkv.to(torch.float16)
|
231 |
+
bias_dtype = bias.dtype
|
232 |
+
bias = bias.to(torch.float16)
|
233 |
+
attention = flash_attn_qkvpacked_func(qkv, bias)
|
234 |
+
attention = attention.to(orig_dtype)
|
235 |
+
bias = bias.to(bias_dtype)
|
236 |
+
else:
|
237 |
+
attention = flash_attn_qkvpacked_func(qkv, bias)
|
238 |
+
|
239 |
+
# attn_mask is 1 for attend and 0 for don't
|
240 |
+
attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
|
241 |
+
return rearrange(attention, 'nnz h d -> nnz (h d)')
|
242 |
+
|
243 |
+
|
244 |
+
# Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
|
245 |
+
class BertSelfOutput(nn.Module):
|
246 |
+
"""Computes the output of the attention layer.
|
247 |
+
|
248 |
+
This module is modeled after the Hugging Face BERT's
|
249 |
+
:class:`~transformers.model.bert.modeling_bert.BertSelfOutput`.
|
250 |
+
The implementation is identical. Rather than use the original module
|
251 |
+
directly, we re-implement it here so that Mosaic BERT's modules will not
|
252 |
+
be affected by any Composer surgery algorithm that modifies Hugging Face
|
253 |
+
BERT modules.
|
254 |
+
"""
|
255 |
+
|
256 |
+
def __init__(self, config):
|
257 |
+
super().__init__()
|
258 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
259 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size,
|
260 |
+
eps=config.layer_norm_eps)
|
261 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
262 |
+
|
263 |
+
def forward(self, hidden_states: torch.Tensor,
|
264 |
+
input_tensor: torch.Tensor) -> torch.Tensor:
|
265 |
+
hidden_states = self.dense(hidden_states)
|
266 |
+
hidden_states = self.dropout(hidden_states)
|
267 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
268 |
+
return hidden_states
|
269 |
+
|
270 |
+
|
271 |
+
class BertUnpadAttention(nn.Module):
|
272 |
+
"""Chains attention, Dropout, and LayerNorm for Mosaic BERT."""
|
273 |
+
|
274 |
+
def __init__(self, config):
|
275 |
+
super().__init__()
|
276 |
+
self.self = BertUnpadSelfAttention(config)
|
277 |
+
self.output = BertSelfOutput(config)
|
278 |
+
|
279 |
+
def forward(
|
280 |
+
self,
|
281 |
+
input_tensor: torch.Tensor,
|
282 |
+
cu_seqlens: torch.Tensor,
|
283 |
+
max_s: int,
|
284 |
+
subset_idx: Optional[torch.Tensor] = None,
|
285 |
+
indices: Optional[torch.Tensor] = None,
|
286 |
+
attn_mask: Optional[torch.Tensor] = None,
|
287 |
+
bias: Optional[torch.Tensor] = None,
|
288 |
+
) -> torch.Tensor:
|
289 |
+
"""Forward pass for scaled self-attention without padding.
|
290 |
+
|
291 |
+
Arguments:
|
292 |
+
input_tensor: (total_nnz, dim)
|
293 |
+
cu_seqlens: (batch + 1,)
|
294 |
+
max_s: int
|
295 |
+
subset_idx: () set of indices whose values we care about at the end of the layer
|
296 |
+
(e.g., the masked tokens, if this is the final layer).
|
297 |
+
indices: None or (total_nnz,)
|
298 |
+
attn_mask: None or (batch, max_seqlen_in_batch)
|
299 |
+
bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
|
300 |
+
"""
|
301 |
+
self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
|
302 |
+
attn_mask, bias)
|
303 |
+
if subset_idx is not None:
|
304 |
+
return self.output(index_first_axis(self_output, subset_idx),
|
305 |
+
index_first_axis(input_tensor, subset_idx))
|
306 |
+
else:
|
307 |
+
return self.output(self_output, input_tensor)
|
308 |
+
|
309 |
+
|
310 |
+
class BertGatedLinearUnitMLP(nn.Module):
|
311 |
+
"""Applies the FFN at the end of each Mosaic BERT layer.
|
312 |
+
|
313 |
+
Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
|
314 |
+
and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality, but
|
315 |
+
introduces Gated Linear Units.
|
316 |
+
|
317 |
+
Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count consistent with that of a
|
318 |
+
standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. For example, a Mosaic BERT constructed with
|
319 |
+
`config.intermediate_size=2048` will have the same parameter footprint as its Hugging Face BERT counterpart constructed
|
320 |
+
with the `config.intermediate_size=3072`.
|
321 |
+
However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the increased
|
322 |
+
parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`.
|
323 |
+
"""
|
324 |
+
|
325 |
+
def __init__(self, config):
|
326 |
+
super().__init__()
|
327 |
+
self.config = config
|
328 |
+
self.gated_layers = nn.Linear(config.hidden_size,
|
329 |
+
config.intermediate_size * 2,
|
330 |
+
bias=False)
|
331 |
+
self.act = nn.GELU(approximate='none')
|
332 |
+
self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
|
333 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
334 |
+
self.layernorm = nn.LayerNorm(config.hidden_size,
|
335 |
+
eps=config.layer_norm_eps)
|
336 |
+
|
337 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
338 |
+
"""Compute new hidden states from current hidden states.
|
339 |
+
|
340 |
+
Args:
|
341 |
+
hidden_states (torch.Tensor): The (unpadded) hidden states from
|
342 |
+
the attention layer [nnz, dim].
|
343 |
+
"""
|
344 |
+
residual_connection = hidden_states
|
345 |
+
# compute the activation
|
346 |
+
hidden_states = self.gated_layers(hidden_states)
|
347 |
+
gated = hidden_states[:, :self.config.intermediate_size]
|
348 |
+
non_gated = hidden_states[:, self.config.intermediate_size:]
|
349 |
+
hidden_states = self.act(gated) * non_gated
|
350 |
+
hidden_states = self.dropout(hidden_states)
|
351 |
+
# multiply by the second matrix
|
352 |
+
hidden_states = self.wo(hidden_states)
|
353 |
+
# add the residual connection and post-LN
|
354 |
+
hidden_states = self.layernorm(hidden_states + residual_connection)
|
355 |
+
return hidden_states
|
356 |
+
|
357 |
+
|
358 |
+
class BertLayer(nn.Module):
|
359 |
+
"""Composes the Mosaic BERT attention and FFN blocks into a single layer."""
|
360 |
+
|
361 |
+
def __init__(self, config):
|
362 |
+
super(BertLayer, self).__init__()
|
363 |
+
self.attention = BertUnpadAttention(config)
|
364 |
+
self.mlp = BertGatedLinearUnitMLP(config)
|
365 |
+
|
366 |
+
def forward(
|
367 |
+
self,
|
368 |
+
hidden_states: torch.Tensor,
|
369 |
+
cu_seqlens: torch.Tensor,
|
370 |
+
seqlen: int,
|
371 |
+
subset_idx: Optional[torch.Tensor] = None,
|
372 |
+
indices: Optional[torch.Tensor] = None,
|
373 |
+
attn_mask: Optional[torch.Tensor] = None,
|
374 |
+
bias: Optional[torch.Tensor] = None,
|
375 |
+
) -> torch.Tensor:
|
376 |
+
"""Forward pass for a BERT layer, including both attention and MLP.
|
377 |
+
|
378 |
+
Args:
|
379 |
+
hidden_states: (total_nnz, dim)
|
380 |
+
cu_seqlens: (batch + 1,)
|
381 |
+
seqlen: int
|
382 |
+
subset_idx: () set of indices whose values we care about at the end of the layer
|
383 |
+
(e.g., the masked tokens, if this is the final layer).
|
384 |
+
indices: None or (total_nnz,)
|
385 |
+
attn_mask: None or (batch, max_seqlen_in_batch)
|
386 |
+
bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
|
387 |
+
"""
|
388 |
+
attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
|
389 |
+
subset_idx, indices, attn_mask, bias)
|
390 |
+
layer_output = self.mlp(attention_output)
|
391 |
+
return layer_output
|
392 |
+
|
393 |
+
|
394 |
+
class BertEncoder(nn.Module):
|
395 |
+
"""A stack of BERT layers providing the backbone of Mosaic BERT.
|
396 |
+
|
397 |
+
This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertEncoder`,
|
398 |
+
but with substantial modifications to implement unpadding and ALiBi.
|
399 |
+
|
400 |
+
Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
|
401 |
+
at padded tokens, and pre-computes attention biases to implement ALiBi.
|
402 |
+
"""
|
403 |
+
|
404 |
+
def __init__(self, config):
|
405 |
+
super().__init__()
|
406 |
+
layer = BertLayer(config)
|
407 |
+
self.layer = nn.ModuleList(
|
408 |
+
[copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
409 |
+
|
410 |
+
self.num_attention_heads = config.num_attention_heads
|
411 |
+
|
412 |
+
# The alibi mask will be dynamically expanded if it is too small for
|
413 |
+
# the input the model receives. But it generally helps to initialize it
|
414 |
+
# to a reasonably large size to help pre-allocate CUDA memory.
|
415 |
+
# The default `alibi_starting_size` is 512.
|
416 |
+
self._current_alibi_size = int(config.alibi_starting_size)
|
417 |
+
self.alibi = torch.zeros(
|
418 |
+
(1, self.num_attention_heads, self._current_alibi_size,
|
419 |
+
self._current_alibi_size))
|
420 |
+
self.rebuild_alibi_tensor(size=config.alibi_starting_size)
|
421 |
+
|
422 |
+
def rebuild_alibi_tensor(self,
|
423 |
+
size: int,
|
424 |
+
device: Optional[Union[torch.device, str]] = None):
|
425 |
+
# Alibi
|
426 |
+
# Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
|
427 |
+
# In the causal case, you can exploit the fact that softmax is invariant to a uniform translation
|
428 |
+
# of the logits, which makes the math work out *after* applying causal masking. If no causal masking
|
429 |
+
# will be applied, it is necessary to construct the diagonal mask.
|
430 |
+
n_heads = self.num_attention_heads
|
431 |
+
|
432 |
+
def _get_alibi_head_slopes(n_heads: int) -> List[float]:
|
433 |
+
|
434 |
+
def get_slopes_power_of_2(n_heads: int) -> List[float]:
|
435 |
+
start = (2**(-2**-(math.log2(n_heads) - 3)))
|
436 |
+
ratio = start
|
437 |
+
return [start * ratio**i for i in range(n_heads)]
|
438 |
+
|
439 |
+
# In the paper, they only train models that have 2^a heads for some a. This function
|
440 |
+
# has some good properties that only occur when the input is a power of 2. To
|
441 |
+
# maintain that even when the number of heads is not a power of 2, we use a
|
442 |
+
# workaround.
|
443 |
+
if math.log2(n_heads).is_integer():
|
444 |
+
return get_slopes_power_of_2(n_heads)
|
445 |
+
|
446 |
+
closest_power_of_2 = 2**math.floor(math.log2(n_heads))
|
447 |
+
slopes_a = get_slopes_power_of_2(closest_power_of_2)
|
448 |
+
slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
|
449 |
+
slopes_b = slopes_b[0::2][:n_heads - closest_power_of_2]
|
450 |
+
return slopes_a + slopes_b
|
451 |
+
|
452 |
+
context_position = torch.arange(size, device=device)[:, None]
|
453 |
+
memory_position = torch.arange(size, device=device)[None, :]
|
454 |
+
relative_position = torch.abs(memory_position - context_position)
|
455 |
+
# [n_heads, max_token_length, max_token_length]
|
456 |
+
relative_position = relative_position.unsqueeze(0).expand(
|
457 |
+
n_heads, -1, -1)
|
458 |
+
slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
|
459 |
+
alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
|
460 |
+
# [1, n_heads, max_token_length, max_token_length]
|
461 |
+
alibi = alibi.unsqueeze(0)
|
462 |
+
assert alibi.shape == torch.Size([1, n_heads, size, size])
|
463 |
+
|
464 |
+
self._current_alibi_size = size
|
465 |
+
self.alibi = alibi
|
466 |
+
|
467 |
+
def forward(
|
468 |
+
self,
|
469 |
+
hidden_states: torch.Tensor,
|
470 |
+
attention_mask: torch.Tensor,
|
471 |
+
output_all_encoded_layers: Optional[bool] = True,
|
472 |
+
subset_mask: Optional[torch.Tensor] = None,
|
473 |
+
) -> List[torch.Tensor]:
|
474 |
+
|
475 |
+
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
476 |
+
extended_attention_mask = extended_attention_mask.to(
|
477 |
+
dtype=next(self.parameters()).dtype) # fp16 compatibility
|
478 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
479 |
+
|
480 |
+
attention_mask_bool = attention_mask.bool()
|
481 |
+
batch, seqlen = hidden_states.shape[:2]
|
482 |
+
# Unpad inputs and mask. It will remove tokens that are padded.
|
483 |
+
# Assume ntokens is total number of tokens (padded and non-padded)
|
484 |
+
# and ntokens_unpad is total number of non-padded tokens.
|
485 |
+
# Then unpadding performs the following compression of the inputs:
|
486 |
+
# hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
|
487 |
+
hidden_states, indices, cu_seqlens, _ = unpad_input(
|
488 |
+
hidden_states, attention_mask_bool)
|
489 |
+
|
490 |
+
# Add alibi matrix to extended_attention_mask
|
491 |
+
if self._current_alibi_size < seqlen:
|
492 |
+
# Rebuild the alibi tensor when needed
|
493 |
+
warnings.warn(
|
494 |
+
f'Increasing alibi size from {self._current_alibi_size} to {seqlen}'
|
495 |
+
)
|
496 |
+
self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device)
|
497 |
+
elif self.alibi.device != hidden_states.device:
|
498 |
+
# Device catch-up
|
499 |
+
self.alibi = self.alibi.to(hidden_states.device)
|
500 |
+
alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
|
501 |
+
attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
|
502 |
+
alibi_attn_mask = attn_bias + alibi_bias
|
503 |
+
|
504 |
+
all_encoder_layers = []
|
505 |
+
if subset_mask is None:
|
506 |
+
for layer_module in self.layer:
|
507 |
+
hidden_states = layer_module(hidden_states,
|
508 |
+
cu_seqlens,
|
509 |
+
seqlen,
|
510 |
+
None,
|
511 |
+
indices,
|
512 |
+
attn_mask=attention_mask,
|
513 |
+
bias=alibi_attn_mask)
|
514 |
+
if output_all_encoded_layers:
|
515 |
+
all_encoder_layers.append(hidden_states)
|
516 |
+
# Pad inputs and mask. It will insert back zero-padded tokens.
|
517 |
+
# Assume ntokens is total number of tokens (padded and non-padded)
|
518 |
+
# and ntokens_unpad is total number of non-padded tokens.
|
519 |
+
# Then padding performs the following de-compression:
|
520 |
+
# hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
|
521 |
+
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
522 |
+
else:
|
523 |
+
for i in range(len(self.layer) - 1):
|
524 |
+
layer_module = self.layer[i]
|
525 |
+
hidden_states = layer_module(hidden_states,
|
526 |
+
cu_seqlens,
|
527 |
+
seqlen,
|
528 |
+
None,
|
529 |
+
indices,
|
530 |
+
attn_mask=attention_mask,
|
531 |
+
bias=alibi_attn_mask)
|
532 |
+
if output_all_encoded_layers:
|
533 |
+
all_encoder_layers.append(hidden_states)
|
534 |
+
subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
|
535 |
+
as_tuple=False).flatten()
|
536 |
+
hidden_states = self.layer[-1](hidden_states,
|
537 |
+
cu_seqlens,
|
538 |
+
seqlen,
|
539 |
+
subset_idx=subset_idx,
|
540 |
+
indices=indices,
|
541 |
+
attn_mask=attention_mask,
|
542 |
+
bias=alibi_attn_mask)
|
543 |
+
|
544 |
+
if not output_all_encoded_layers:
|
545 |
+
all_encoder_layers.append(hidden_states)
|
546 |
+
return all_encoder_layers
|
547 |
+
|
548 |
+
|
549 |
+
class BertPooler(nn.Module):
|
550 |
+
|
551 |
+
def __init__(self, config):
|
552 |
+
super(BertPooler, self).__init__()
|
553 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
554 |
+
self.activation = nn.Tanh()
|
555 |
+
|
556 |
+
def forward(self,
|
557 |
+
hidden_states: torch.Tensor,
|
558 |
+
pool: Optional[bool] = True) -> torch.Tensor:
|
559 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
560 |
+
# to the first token.
|
561 |
+
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
562 |
+
pooled_output = self.dense(first_token_tensor)
|
563 |
+
pooled_output = self.activation(pooled_output)
|
564 |
+
return pooled_output
|
565 |
+
|
566 |
+
|
567 |
+
class BertPredictionHeadTransform(nn.Module):
|
568 |
+
|
569 |
+
def __init__(self, config):
|
570 |
+
super().__init__()
|
571 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
572 |
+
if isinstance(config.hidden_act, str):
|
573 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
574 |
+
else:
|
575 |
+
self.transform_act_fn = config.hidden_act
|
576 |
+
self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)
|
577 |
+
|
578 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
579 |
+
hidden_states = self.dense(hidden_states)
|
580 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
581 |
+
hidden_states = self.LayerNorm(hidden_states)
|
582 |
+
return hidden_states
|
583 |
+
|
584 |
+
|
585 |
+
class BertModel(BertPreTrainedModel):
|
586 |
+
"""Overall BERT model.
|
587 |
+
|
588 |
+
Args:
|
589 |
+
config: a BertConfig class instance with the configuration to build a new model
|
590 |
+
|
591 |
+
Inputs:
|
592 |
+
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
593 |
+
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
594 |
+
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
595 |
+
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
596 |
+
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
597 |
+
a `sentence B` token (see BERT paper for more details).
|
598 |
+
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
|
599 |
+
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
600 |
+
input sequence length in the current batch. It's the mask that we typically use for attention when
|
601 |
+
a batch has varying length sentences.
|
602 |
+
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
|
603 |
+
|
604 |
+
Outputs: Tuple of (encoded_layers, pooled_output)
|
605 |
+
`encoded_layers`: controlled by `output_all_encoded_layers` argument:
|
606 |
+
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
|
607 |
+
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
|
608 |
+
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
|
609 |
+
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
|
610 |
+
to the last attention block of shape [batch_size, sequence_length, hidden_size],
|
611 |
+
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
|
612 |
+
classifier pretrained on top of the hidden state associated to the first character of the
|
613 |
+
input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
|
614 |
+
|
615 |
+
Example usage:
|
616 |
+
```python
|
617 |
+
# Already been converted into WordPiece token ids
|
618 |
+
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
619 |
+
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
620 |
+
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
621 |
+
config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
622 |
+
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
623 |
+
model = BertModel(config=config)
|
624 |
+
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
625 |
+
```
|
626 |
+
"""
|
627 |
+
|
628 |
+
def __init__(self, config, add_pooling_layer=True):
|
629 |
+
super(BertModel, self).__init__(config)
|
630 |
+
self.embeddings = BertEmbeddings(config)
|
631 |
+
self.encoder = BertEncoder(config)
|
632 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
633 |
+
self.post_init()
|
634 |
+
|
635 |
+
def get_input_embeddings(self):
|
636 |
+
return self.embeddings.word_embeddings
|
637 |
+
|
638 |
+
def set_input_embeddings(self, value):
|
639 |
+
self.embeddings.word_embeddings = value
|
640 |
+
|
641 |
+
def forward(
|
642 |
+
self,
|
643 |
+
input_ids: torch.Tensor,
|
644 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
645 |
+
attention_mask: Optional[torch.Tensor] = None,
|
646 |
+
position_ids: Optional[torch.Tensor] = None,
|
647 |
+
output_all_encoded_layers: Optional[bool] = False,
|
648 |
+
masked_tokens_mask: Optional[torch.Tensor] = None,
|
649 |
+
**kwargs
|
650 |
+
) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
|
651 |
+
if attention_mask is None:
|
652 |
+
attention_mask = torch.ones_like(input_ids)
|
653 |
+
if token_type_ids is None:
|
654 |
+
token_type_ids = torch.zeros_like(input_ids)
|
655 |
+
|
656 |
+
embedding_output = self.embeddings(input_ids, token_type_ids,
|
657 |
+
position_ids)
|
658 |
+
|
659 |
+
subset_mask = []
|
660 |
+
first_col_mask = []
|
661 |
+
|
662 |
+
if masked_tokens_mask is None:
|
663 |
+
subset_mask = None
|
664 |
+
else:
|
665 |
+
first_col_mask = torch.zeros_like(masked_tokens_mask)
|
666 |
+
first_col_mask[:, 0] = True
|
667 |
+
subset_mask = masked_tokens_mask | first_col_mask
|
668 |
+
|
669 |
+
encoder_outputs = self.encoder(
|
670 |
+
embedding_output,
|
671 |
+
attention_mask,
|
672 |
+
output_all_encoded_layers=output_all_encoded_layers,
|
673 |
+
subset_mask=subset_mask)
|
674 |
+
|
675 |
+
if masked_tokens_mask is None:
|
676 |
+
sequence_output = encoder_outputs[-1]
|
677 |
+
pooled_output = self.pooler(
|
678 |
+
sequence_output) if self.pooler is not None else None
|
679 |
+
else:
|
680 |
+
# TD [2022-03-01]: the indexing here is very tricky.
|
681 |
+
attention_mask_bool = attention_mask.bool()
|
682 |
+
subset_idx = subset_mask[attention_mask_bool] # type: ignore
|
683 |
+
sequence_output = encoder_outputs[-1][
|
684 |
+
masked_tokens_mask[attention_mask_bool][subset_idx]]
|
685 |
+
if self.pooler is not None:
|
686 |
+
pool_input = encoder_outputs[-1][
|
687 |
+
first_col_mask[attention_mask_bool][subset_idx]]
|
688 |
+
pooled_output = self.pooler(pool_input, pool=False)
|
689 |
+
else:
|
690 |
+
pooled_output = None
|
691 |
+
|
692 |
+
if not output_all_encoded_layers:
|
693 |
+
encoder_outputs = sequence_output
|
694 |
+
|
695 |
+
if self.pooler is not None:
|
696 |
+
return encoder_outputs, pooled_output
|
697 |
+
|
698 |
+
return encoder_outputs, None
|
699 |
+
|
700 |
+
|
701 |
+
###################
|
702 |
+
# Bert Heads
|
703 |
+
###################
|
704 |
+
class BertLMPredictionHead(nn.Module):
|
705 |
+
|
706 |
+
def __init__(self, config, bert_model_embedding_weights):
|
707 |
+
super().__init__()
|
708 |
+
self.transform = BertPredictionHeadTransform(config)
|
709 |
+
# The output weights are the same as the input embeddings, but there is
|
710 |
+
# an output-only bias for each token.
|
711 |
+
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
|
712 |
+
bert_model_embedding_weights.size(0))
|
713 |
+
self.decoder.weight = bert_model_embedding_weights
|
714 |
+
|
715 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
716 |
+
hidden_states = self.transform(hidden_states)
|
717 |
+
hidden_states = self.decoder(hidden_states)
|
718 |
+
return hidden_states
|
719 |
+
|
720 |
+
|
721 |
+
class BertOnlyMLMHead(nn.Module):
|
722 |
+
|
723 |
+
def __init__(self, config, bert_model_embedding_weights):
|
724 |
+
super().__init__()
|
725 |
+
self.predictions = BertLMPredictionHead(config,
|
726 |
+
bert_model_embedding_weights)
|
727 |
+
|
728 |
+
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
729 |
+
prediction_scores = self.predictions(sequence_output)
|
730 |
+
return prediction_scores
|
731 |
+
|
732 |
+
|
733 |
+
class BertOnlyNSPHead(nn.Module):
|
734 |
+
|
735 |
+
def __init__(self, config):
|
736 |
+
super().__init__()
|
737 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
738 |
+
|
739 |
+
def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
|
740 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
741 |
+
return seq_relationship_score
|
742 |
+
|
743 |
+
|
744 |
+
#####################
|
745 |
+
# Various Bert models
|
746 |
+
#####################
|
747 |
+
|
748 |
+
|
749 |
+
class BertForPreTraining(BertPreTrainedModel):
|
750 |
+
#TBD: Coming in Future Commit
|
751 |
+
pass
|
752 |
+
|
753 |
+
|
754 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
755 |
+
#TBD: Coming in Future Commit
|
756 |
+
pass
|
757 |
+
|
758 |
+
|
759 |
+
class BertForMaskedLM(BertPreTrainedModel):
|
760 |
+
|
761 |
+
def __init__(self, config):
|
762 |
+
super().__init__(config)
|
763 |
+
|
764 |
+
if config.is_decoder:
|
765 |
+
warnings.warn(
|
766 |
+
'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
|
767 |
+
'bi-directional self-attention.')
|
768 |
+
|
769 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
770 |
+
self.cls = BertOnlyMLMHead(config,
|
771 |
+
self.bert.embeddings.word_embeddings.weight)
|
772 |
+
|
773 |
+
# Initialize weights and apply final processing
|
774 |
+
self.post_init()
|
775 |
+
|
776 |
+
@classmethod
|
777 |
+
def from_composer(cls,
|
778 |
+
pretrained_checkpoint,
|
779 |
+
state_dict=None,
|
780 |
+
cache_dir=None,
|
781 |
+
from_tf=False,
|
782 |
+
config=None,
|
783 |
+
*inputs,
|
784 |
+
**kwargs):
|
785 |
+
"""Load from pre-trained."""
|
786 |
+
model = cls(config, *inputs, **kwargs)
|
787 |
+
if from_tf:
|
788 |
+
raise ValueError(
|
789 |
+
'Mosaic BERT does not support loading TensorFlow weights.')
|
790 |
+
|
791 |
+
state_dict = torch.load(pretrained_checkpoint)
|
792 |
+
# If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
|
793 |
+
consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
|
794 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict,
|
795 |
+
strict=False)
|
796 |
+
|
797 |
+
if len(missing_keys) > 0:
|
798 |
+
logger.warning(
|
799 |
+
f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
|
800 |
+
)
|
801 |
+
if len(unexpected_keys) > 0:
|
802 |
+
logger.warning(
|
803 |
+
f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
|
804 |
+
)
|
805 |
+
|
806 |
+
return model
|
807 |
+
|
808 |
+
def get_output_embeddings(self):
|
809 |
+
return self.cls.predictions.decoder
|
810 |
+
|
811 |
+
def set_output_embeddings(self, new_embeddings):
|
812 |
+
self.cls.predictions.decoder = new_embeddings
|
813 |
+
|
814 |
+
def forward(
|
815 |
+
self,
|
816 |
+
input_ids: Optional[torch.Tensor] = None,
|
817 |
+
attention_mask: Optional[torch.Tensor] = None,
|
818 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
819 |
+
position_ids: Optional[torch.Tensor] = None,
|
820 |
+
head_mask: Optional[torch.Tensor] = None,
|
821 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
822 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
823 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
824 |
+
labels: Optional[torch.Tensor] = None,
|
825 |
+
output_attentions: Optional[bool] = None,
|
826 |
+
output_hidden_states: Optional[bool] = None,
|
827 |
+
return_dict: Optional[bool] = None,
|
828 |
+
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
829 |
+
# labels should be a `torch.LongTensor` of shape
|
830 |
+
# `(batch_size, sequence_length)`. These are used for computing the
|
831 |
+
# masked language modeling loss.
|
832 |
+
#
|
833 |
+
# Indices should be in `[-100, 0, ..., config.vocab_size]` (see
|
834 |
+
# `input_ids` docstring) Tokens with indices set to `-100` are ignored
|
835 |
+
# (masked), the loss is only computed for the tokens with labels in `[0,
|
836 |
+
# ..., config.vocab_size]`
|
837 |
+
#
|
838 |
+
# Prediction scores are only computed for masked tokens and the (bs,
|
839 |
+
# seqlen) dimensions are flattened
|
840 |
+
if (input_ids is not None) == (inputs_embeds is not None):
|
841 |
+
raise ValueError('Must specify either input_ids or input_embeds!')
|
842 |
+
|
843 |
+
if labels is None:
|
844 |
+
masked_tokens_mask = None
|
845 |
+
else:
|
846 |
+
masked_tokens_mask = labels > 0
|
847 |
+
|
848 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
849 |
+
|
850 |
+
outputs = self.bert(
|
851 |
+
input_ids,
|
852 |
+
attention_mask=attention_mask,
|
853 |
+
token_type_ids=token_type_ids,
|
854 |
+
position_ids=position_ids,
|
855 |
+
head_mask=head_mask,
|
856 |
+
inputs_embeds=inputs_embeds,
|
857 |
+
encoder_hidden_states=encoder_hidden_states,
|
858 |
+
encoder_attention_mask=encoder_attention_mask,
|
859 |
+
output_attentions=output_attentions,
|
860 |
+
output_hidden_states=output_hidden_states,
|
861 |
+
return_dict=return_dict,
|
862 |
+
masked_tokens_mask=masked_tokens_mask,
|
863 |
+
)
|
864 |
+
|
865 |
+
sequence_output = outputs[0]
|
866 |
+
prediction_scores = self.cls(sequence_output)
|
867 |
+
|
868 |
+
loss = None
|
869 |
+
if labels is not None:
|
870 |
+
# Compute loss
|
871 |
+
loss_fct = nn.CrossEntropyLoss()
|
872 |
+
masked_token_idx = torch.nonzero(labels.flatten() > 0,
|
873 |
+
as_tuple=False).flatten()
|
874 |
+
loss = loss_fct(prediction_scores,
|
875 |
+
labels.flatten()[masked_token_idx])
|
876 |
+
|
877 |
+
assert input_ids is not None, 'Coding error; please open an issue'
|
878 |
+
batch, seqlen = input_ids.shape[:2]
|
879 |
+
prediction_scores = rearrange(index_put_first_axis(
|
880 |
+
prediction_scores, masked_token_idx, batch * seqlen),
|
881 |
+
'(b s) d -> b s d',
|
882 |
+
b=batch)
|
883 |
+
|
884 |
+
if not return_dict:
|
885 |
+
output = (prediction_scores,) + outputs[2:]
|
886 |
+
return ((loss,) + output) if loss is not None else output
|
887 |
+
|
888 |
+
return MaskedLMOutput(
|
889 |
+
loss=loss,
|
890 |
+
logits=prediction_scores,
|
891 |
+
hidden_states=None,
|
892 |
+
attentions=None,
|
893 |
+
)
|
894 |
+
|
895 |
+
def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
|
896 |
+
attention_mask: torch.Tensor,
|
897 |
+
**model_kwargs):
|
898 |
+
input_shape = input_ids.shape
|
899 |
+
effective_batch_size = input_shape[0]
|
900 |
+
|
901 |
+
# add a dummy token
|
902 |
+
if self.config.pad_token_id is None:
|
903 |
+
raise ValueError('The PAD token should be defined for generation')
|
904 |
+
|
905 |
+
attention_mask = torch.cat([
|
906 |
+
attention_mask,
|
907 |
+
attention_mask.new_zeros((attention_mask.shape[0], 1))
|
908 |
+
],
|
909 |
+
dim=-1)
|
910 |
+
dummy_token = torch.full((effective_batch_size, 1),
|
911 |
+
self.config.pad_token_id,
|
912 |
+
dtype=torch.long,
|
913 |
+
device=input_ids.device)
|
914 |
+
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
915 |
+
|
916 |
+
return {'input_ids': input_ids, 'attention_mask': attention_mask}
|
917 |
+
|
918 |
+
|
919 |
+
class BertForNextSentencePrediction(BertPreTrainedModel):
|
920 |
+
#TBD: Push in future commit
|
921 |
+
pass
|
922 |
+
|
923 |
+
|
924 |
+
class BertForSequenceClassification(BertPreTrainedModel):
|
925 |
+
"""Bert Model transformer with a sequence classification/regression head.
|
926 |
+
|
927 |
+
This head is just a linear layer on top of the pooled output. Used for,
|
928 |
+
e.g., GLUE tasks.
|
929 |
+
"""
|
930 |
+
|
931 |
+
def __init__(self, config):
|
932 |
+
super().__init__(config)
|
933 |
+
self.num_labels = config.num_labels
|
934 |
+
self.config = config
|
935 |
+
|
936 |
+
self.bert = BertModel(config)
|
937 |
+
classifier_dropout = (config.classifier_dropout
|
938 |
+
if config.classifier_dropout is not None else
|
939 |
+
config.hidden_dropout_prob)
|
940 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
941 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
942 |
+
|
943 |
+
# Initialize weights and apply final processing
|
944 |
+
self.post_init()
|
945 |
+
|
946 |
+
@classmethod
|
947 |
+
def from_composer(cls,
|
948 |
+
pretrained_checkpoint,
|
949 |
+
state_dict=None,
|
950 |
+
cache_dir=None,
|
951 |
+
from_tf=False,
|
952 |
+
config=None,
|
953 |
+
*inputs,
|
954 |
+
**kwargs):
|
955 |
+
"""Load from pre-trained."""
|
956 |
+
model = cls(config, *inputs, **kwargs)
|
957 |
+
if from_tf:
|
958 |
+
raise ValueError(
|
959 |
+
'Mosaic BERT does not support loading TensorFlow weights.')
|
960 |
+
|
961 |
+
state_dict = torch.load(pretrained_checkpoint)
|
962 |
+
# If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
|
963 |
+
consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
|
964 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict,
|
965 |
+
strict=False)
|
966 |
+
|
967 |
+
if len(missing_keys) > 0:
|
968 |
+
logger.warning(
|
969 |
+
f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
|
970 |
+
)
|
971 |
+
if len(unexpected_keys) > 0:
|
972 |
+
logger.warning(
|
973 |
+
f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
|
974 |
+
)
|
975 |
+
|
976 |
+
return model
|
977 |
+
|
978 |
+
def forward(
|
979 |
+
self,
|
980 |
+
input_ids: Optional[torch.Tensor] = None,
|
981 |
+
attention_mask: Optional[torch.Tensor] = None,
|
982 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
983 |
+
position_ids: Optional[torch.Tensor] = None,
|
984 |
+
head_mask: Optional[torch.Tensor] = None,
|
985 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
986 |
+
labels: Optional[torch.Tensor] = None,
|
987 |
+
output_attentions: Optional[bool] = None,
|
988 |
+
output_hidden_states: Optional[bool] = None,
|
989 |
+
return_dict: Optional[bool] = None,
|
990 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
991 |
+
# labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
992 |
+
# Labels for computing the sequence classification/regression loss.
|
993 |
+
# Indices should be in `[0, ..., config.num_labels - 1]`.
|
994 |
+
# If `config.num_labels == 1` a regression loss is computed
|
995 |
+
# (mean-square loss). If `config.num_labels > 1` a classification loss
|
996 |
+
# is computed (cross-entropy).
|
997 |
+
|
998 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
999 |
+
|
1000 |
+
outputs = self.bert(
|
1001 |
+
input_ids,
|
1002 |
+
attention_mask=attention_mask,
|
1003 |
+
token_type_ids=token_type_ids,
|
1004 |
+
position_ids=position_ids,
|
1005 |
+
head_mask=head_mask,
|
1006 |
+
inputs_embeds=inputs_embeds,
|
1007 |
+
output_attentions=output_attentions,
|
1008 |
+
output_hidden_states=output_hidden_states,
|
1009 |
+
return_dict=return_dict,
|
1010 |
+
)
|
1011 |
+
|
1012 |
+
pooled_output = outputs[1]
|
1013 |
+
|
1014 |
+
pooled_output = self.dropout(pooled_output)
|
1015 |
+
logits = self.classifier(pooled_output)
|
1016 |
+
|
1017 |
+
loss = None
|
1018 |
+
if labels is not None:
|
1019 |
+
# Compute loss
|
1020 |
+
if self.config.problem_type is None:
|
1021 |
+
if self.num_labels == 1:
|
1022 |
+
self.config.problem_type = 'regression'
|
1023 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or
|
1024 |
+
labels.dtype == torch.int):
|
1025 |
+
self.config.problem_type = 'single_label_classification'
|
1026 |
+
else:
|
1027 |
+
self.config.problem_type = 'multi_label_classification'
|
1028 |
+
|
1029 |
+
if self.config.problem_type == 'regression':
|
1030 |
+
loss_fct = nn.MSELoss()
|
1031 |
+
if self.num_labels == 1:
|
1032 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
1033 |
+
else:
|
1034 |
+
loss = loss_fct(logits, labels)
|
1035 |
+
elif self.config.problem_type == 'single_label_classification':
|
1036 |
+
loss_fct = nn.CrossEntropyLoss()
|
1037 |
+
loss = loss_fct(logits.view(-1, self.num_labels),
|
1038 |
+
labels.view(-1))
|
1039 |
+
elif self.config.problem_type == 'multi_label_classification':
|
1040 |
+
loss_fct = nn.BCEWithLogitsLoss()
|
1041 |
+
loss = loss_fct(logits, labels)
|
1042 |
+
|
1043 |
+
if not return_dict:
|
1044 |
+
output = (logits,) + outputs[2:]
|
1045 |
+
return ((loss,) + output) if loss is not None else output
|
1046 |
+
|
1047 |
+
return SequenceClassifierOutput(
|
1048 |
+
loss=loss,
|
1049 |
+
logits=logits,
|
1050 |
+
hidden_states=None,
|
1051 |
+
attentions=None,
|
1052 |
+
)
|
1053 |
+
|
1054 |
+
|
1055 |
+
class BertForMultipleChoice(BertPreTrainedModel):
|
1056 |
+
#TBD: Push in future commit
|
1057 |
+
pass
|
1058 |
+
|
1059 |
+
|
1060 |
+
class BertForTokenClassification(BertPreTrainedModel):
|
1061 |
+
#TBD: Push in future commit
|
1062 |
+
pass
|
1063 |
+
|
1064 |
+
|
1065 |
+
class BertForQuestionAnswering(BertPreTrainedModel):
|
1066 |
+
"""Bert Model with a span classification head.
|
1067 |
+
|
1068 |
+
This is used for extractive question-answering tasks like SQuAD (a linear
|
1069 |
+
layers on top of the hidden states' output to compute `span start logits`
|
1070 |
+
and `span end logits`).
|
1071 |
+
"""
|
1072 |
+
#TBD: Push in future commit
|
bert_padding.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML Examples authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
# Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
|
5 |
+
# Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
|
6 |
+
|
7 |
+
"""Helper functions for padding and unpadding batches.
|
8 |
+
|
9 |
+
These functions are used extensively throughout the Mosaic BERT implementation
|
10 |
+
in `bert_layers.py`.
|
11 |
+
"""
|
12 |
+
|
13 |
+
from typing import Tuple, cast
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from einops import rearrange, repeat
|
18 |
+
|
19 |
+
|
20 |
+
class IndexFirstAxis(torch.autograd.Function):
|
21 |
+
|
22 |
+
@staticmethod
|
23 |
+
def forward(ctx, input: torch.Tensor,
|
24 |
+
indices: torch.Tensor) -> torch.Tensor:
|
25 |
+
"""Get just the values of `input` which are at `indices`.
|
26 |
+
|
27 |
+
Arguments:
|
28 |
+
ctx: the autograd context object
|
29 |
+
input: (b, ...) 2+ dimensional tensor
|
30 |
+
indices: (num_idx) 1D tensor
|
31 |
+
"""
|
32 |
+
ctx.save_for_backward(indices)
|
33 |
+
assert input.ndim >= 2
|
34 |
+
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[
|
35 |
+
1:] # type: ignore
|
36 |
+
second_dim = other_shape.numel(
|
37 |
+
) # product of sizes of all but first dimension
|
38 |
+
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
39 |
+
return torch.gather(
|
40 |
+
rearrange(input, 'b ... -> b (...)'), # (b, ...) -> (b, second_dim)
|
41 |
+
0,
|
42 |
+
repeat(indices, 'z -> z d',
|
43 |
+
d=second_dim) # (indices,) -> (indices, second_dim)
|
44 |
+
).reshape(-1, *other_shape) # (num_idx, ...)
|
45 |
+
|
46 |
+
@staticmethod
|
47 |
+
def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
|
48 |
+
indices, = ctx.saved_tensors
|
49 |
+
assert grad_output.ndim >= 2
|
50 |
+
other_shape = grad_output.shape[1:]
|
51 |
+
grad_output = rearrange(grad_output, 'b ... -> b (...)')
|
52 |
+
grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
|
53 |
+
device=grad_output.device,
|
54 |
+
dtype=grad_output.dtype)
|
55 |
+
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
56 |
+
# grad_input[indices] = grad_output
|
57 |
+
grad_input.scatter_(0,
|
58 |
+
repeat(indices, 'z -> z d', d=grad_output.shape[1]),
|
59 |
+
grad_output)
|
60 |
+
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
61 |
+
|
62 |
+
|
63 |
+
index_first_axis = IndexFirstAxis.apply
|
64 |
+
|
65 |
+
|
66 |
+
class IndexPutFirstAxis(torch.autograd.Function):
|
67 |
+
|
68 |
+
@staticmethod
|
69 |
+
def forward(ctx, values: torch.Tensor, indices: torch.Tensor,
|
70 |
+
first_axis_dim) -> torch.Tensor:
|
71 |
+
ctx.save_for_backward(indices)
|
72 |
+
assert indices.ndim == 1
|
73 |
+
assert values.ndim >= 2
|
74 |
+
output = torch.zeros(first_axis_dim,
|
75 |
+
*values.shape[1:],
|
76 |
+
device=values.device,
|
77 |
+
dtype=values.dtype)
|
78 |
+
output[indices] = values
|
79 |
+
return output
|
80 |
+
|
81 |
+
@staticmethod
|
82 |
+
def backward(ctx,
|
83 |
+
grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
|
84 |
+
indices, = ctx.saved_tensors
|
85 |
+
grad_values = grad_output[indices]
|
86 |
+
return grad_values, None, None
|
87 |
+
|
88 |
+
|
89 |
+
index_put_first_axis = IndexPutFirstAxis.apply
|
90 |
+
|
91 |
+
|
92 |
+
def unpad_input(
|
93 |
+
hidden_states: torch.Tensor,
|
94 |
+
attention_mask: torch.Tensor,
|
95 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
|
96 |
+
"""Remove padding from input sequences.
|
97 |
+
|
98 |
+
Arguments:
|
99 |
+
hidden_states: (batch, seqlen, ...)
|
100 |
+
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
104 |
+
indices: (total_nnz)
|
105 |
+
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
|
106 |
+
max_seqlen_in_batch: int ()
|
107 |
+
"""
|
108 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
109 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
110 |
+
max_seqlen_in_batch = int(seqlens_in_batch.max().item())
|
111 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32),
|
112 |
+
(1, 0))
|
113 |
+
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
114 |
+
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
115 |
+
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
116 |
+
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
|
117 |
+
# so we write custom forward and backward to make it a bit faster.
|
118 |
+
hidden_states = cast(
|
119 |
+
torch.Tensor,
|
120 |
+
index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'),
|
121 |
+
indices))
|
122 |
+
return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
|
123 |
+
|
124 |
+
|
125 |
+
def unpad_input_only(
|
126 |
+
hidden_states: torch.Tensor,
|
127 |
+
attention_mask: torch.Tensor,
|
128 |
+
) -> torch.Tensor:
|
129 |
+
"""Like unpad_input, but only return the unpadded first tensor.
|
130 |
+
|
131 |
+
Save a small amount of overhead.
|
132 |
+
|
133 |
+
Arguments:
|
134 |
+
hidden_states: (batch, seqlen, ...)
|
135 |
+
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
139 |
+
"""
|
140 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
141 |
+
return index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'),
|
142 |
+
indices)
|
143 |
+
|
144 |
+
|
145 |
+
def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int,
|
146 |
+
seqlen: int) -> torch.Tensor:
|
147 |
+
"""Add padding to sequences.
|
148 |
+
|
149 |
+
Arguments:
|
150 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
151 |
+
indices: (total_nnz)
|
152 |
+
batch: int batch_size
|
153 |
+
seqlen: int max sequence length
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
hidden_states: (batch, seqlen, ...)
|
157 |
+
"""
|
158 |
+
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
|
159 |
+
return rearrange(output, '(b s) ... -> b s ...', b=batch)
|
config.json
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "hum-lodestone-v1",
|
3 |
+
"alibi_starting_size": 4096,
|
4 |
+
"architectures": [
|
5 |
+
"BertModel"
|
6 |
+
],
|
7 |
+
"attention_probs_dropout_prob": 0.0,
|
8 |
+
"auto_map": {
|
9 |
+
"AutoConfig": "configuration_bert.BertConfig",
|
10 |
+
"AutoModel": "bert_layers.BertModel",
|
11 |
+
"AutoModelForMaskedLM": "bert_layers.BertForMaskedLM"
|
12 |
+
},
|
13 |
+
"classifier_dropout": null,
|
14 |
+
"gradient_checkpointing": false,
|
15 |
+
"hidden_act": "gelu",
|
16 |
+
"hidden_dropout_prob": 0.1,
|
17 |
+
"hidden_size": 768,
|
18 |
+
"initializer_range": 0.02,
|
19 |
+
"intermediate_size": 3072,
|
20 |
+
"layer_norm_eps": 1e-12,
|
21 |
+
"max_position_embeddings": 512,
|
22 |
+
"model_type": "bert",
|
23 |
+
"num_attention_heads": 12,
|
24 |
+
"num_hidden_layers": 12,
|
25 |
+
"pad_token_id": 0,
|
26 |
+
"position_embedding_type": "absolute",
|
27 |
+
"tokenizer_class": "BertTokenizerFast",
|
28 |
+
"torch_dtype": "bfloat16",
|
29 |
+
"transformers_version": "4.28.1",
|
30 |
+
"type_vocab_size": 2,
|
31 |
+
"use_cache": true,
|
32 |
+
"vocab_size": 30528
|
33 |
+
}
|
config_sentence_transformers.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"__version__": {
|
3 |
+
"sentence_transformers": "2.2.2",
|
4 |
+
"transformers": "4.28.1",
|
5 |
+
"pytorch": "2.0.1+cu117"
|
6 |
+
}
|
7 |
+
}
|
configuration_bert.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML Examples authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
from transformers import BertConfig as TransformersBertConfig
|
5 |
+
|
6 |
+
|
7 |
+
class BertConfig(TransformersBertConfig):
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
alibi_starting_size: int = 512,
|
12 |
+
attention_probs_dropout_prob: float = 0.0,
|
13 |
+
**kwargs,
|
14 |
+
):
|
15 |
+
"""Configuration class for MosaicBert.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
alibi_starting_size (int): Use `alibi_starting_size` to determine how large of an alibi tensor to
|
19 |
+
create when initializing the model. You should be able to ignore this parameter in most cases.
|
20 |
+
Defaults to 512.
|
21 |
+
attention_probs_dropout_prob (float): By default, turn off attention dropout in Mosaic BERT
|
22 |
+
(otherwise, Flash Attention will be off by default). Defaults to 0.0.
|
23 |
+
"""
|
24 |
+
super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
|
25 |
+
self.alibi_starting_size = alibi_starting_size
|
data_records.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"AllNLI.json.gz": 277230, "CodeSearchNet.json.gz": 1375067, "NQ-train_pairs.json.gz": 100231, "PAQ_pairs.json.gz": 64371441, "S2ORC_citation_pairs.json.gz": 52603982, "S2ORC_citations_abstracts.json.gz": 252102397, "S2ORC_title_abstract.json.gz": 41769185, "SimpleWiki.json.gz": 102225, "TriviaQA_pairs.json.gz": 73346, "WikiAnswers_pairs.json.gz": 77427422, "agnews.json.gz": 1157745, "altlex.json.gz": 112696, "amazon-qa.json.gz": 2507114, "amazon_review_2018.json.gz": 87877725, "ccnews_title_text.json.gz": 614664, "cnn_dailymail.json.gz": 311971, "coco_captions.json.gz": 828395, "eli5_question_answer.json.gz": 325475, "fever_train.json.gz": 139051, "flickr30k_captions.json.gz": 317695, "gooaq_pairs.json.gz": 3012496, "msmarco-query_passage.json.gz": 532751, "msmarco-query_passage_negative.json.gz": 9144553, "npr.json.gz": 594384, "quora_duplicates.json.gz": 103663, "quora_duplicates_triplets.json.gz": 103663, "reddit-title-body/reddit_title_text_2010.json.gz": 431782, "reddit-title-body/reddit_title_text_2011.json.gz": 1673264, "reddit-title-body/reddit_title_text_2012.json.gz": 3727526, "reddit-title-body/reddit_title_text_2013.json.gz": 5713956, "reddit-title-body/reddit_title_text_2014.json.gz": 8538976, "reddit-title-body/reddit_title_text_2015.json.gz": 11064453, "reddit-title-body/reddit_title_text_2016.json.gz": 12224789, "reddit-title-body/reddit_title_text_2017.json.gz": 13558139, "reddit-title-body/reddit_title_text_2018.json.gz": 15552110, "reddit-title-body/reddit_title_text_2019.json.gz": 19224970, "reddit-title-body/reddit_title_text_2020.json.gz": 23030988, "reddit-title-body/reddit_title_text_2021.json.gz": 12704958, "reddit_2015.json.gz": 135108166, "reddit_2016.json.gz": 159164386, "reddit_2017.json.gz": 191485219, "reddit_2018.json.gz": 240726659, "searchQA_question_top5_snippets_merged.json.gz": 582261, "searchQA_question_topSnippet.json.gz": 117384, "sentence-compression.json.gz": 180000, "specter_train_triples.json.gz": 684100, "squad_pairs.json.gz": 87599, "stackexchange_duplicate_questions_body_body.json.gz": 250459, "stackexchange_duplicate_questions_title-body_title-body.json.gz": 250518, "stackexchange_duplicate_questions_title_title.json.gz": 304524, "stackexchange_title_best_voted_answer_jsonl/3dprinting.stackexchange.com.json.gz": 3488, "stackexchange_title_best_voted_answer_jsonl/academia.stackexchange.com.json.gz": 32137, "stackexchange_title_best_voted_answer_jsonl/ai.stackexchange.com.json.gz": 5763, "stackexchange_title_best_voted_answer_jsonl/android.stackexchange.com.json.gz": 38077, "stackexchange_title_best_voted_answer_jsonl/anime.stackexchange.com.json.gz": 10131, "stackexchange_title_best_voted_answer_jsonl/apple.stackexchange.com.json.gz": 92487, "stackexchange_title_best_voted_answer_jsonl/arduino.stackexchange.com.json.gz": 16281, "stackexchange_title_best_voted_answer_jsonl/askubuntu.com.json.gz": 267135, "stackexchange_title_best_voted_answer_jsonl/astronomy.stackexchange.com.json.gz": 9086, "stackexchange_title_best_voted_answer_jsonl/aviation.stackexchange.com.json.gz": 18755, "stackexchange_title_best_voted_answer_jsonl/avp.stackexchange.com.json.gz": 6450, "stackexchange_title_best_voted_answer_jsonl/beer.stackexchange.com.json.gz": 1012, "stackexchange_title_best_voted_answer_jsonl/bicycles.stackexchange.com.json.gz": 15708, "stackexchange_title_best_voted_answer_jsonl/bioinformatics.stackexchange.com.json.gz": 3135, "stackexchange_title_best_voted_answer_jsonl/biology.stackexchange.com.json.gz": 19277, "stackexchange_title_best_voted_answer_jsonl/bitcoin.stackexchange.com.json.gz": 22474, "stackexchange_title_best_voted_answer_jsonl/blender.stackexchange.com.json.gz": 54153, "stackexchange_title_best_voted_answer_jsonl/boardgames.stackexchange.com.json.gz": 11805, "stackexchange_title_best_voted_answer_jsonl/bricks.stackexchange.com.json.gz": 3530, "stackexchange_title_best_voted_answer_jsonl/buddhism.stackexchange.com.json.gz": 6787, "stackexchange_title_best_voted_answer_jsonl/cardano.stackexchange.com.json.gz": 248, "stackexchange_title_best_voted_answer_jsonl/chemistry.stackexchange.com.json.gz": 27061, "stackexchange_title_best_voted_answer_jsonl/chess.stackexchange.com.json.gz": 6392, "stackexchange_title_best_voted_answer_jsonl/chinese.stackexchange.com.json.gz": 8646, "stackexchange_title_best_voted_answer_jsonl/christianity.stackexchange.com.json.gz": 11498, "stackexchange_title_best_voted_answer_jsonl/civicrm.stackexchange.com.json.gz": 10648, "stackexchange_title_best_voted_answer_jsonl/codegolf.stackexchange.com.json.gz": 8211, "stackexchange_title_best_voted_answer_jsonl/codereview.stackexchange.com.json.gz": 41748, "stackexchange_title_best_voted_answer_jsonl/coffee.stackexchange.com.json.gz": 1188, "stackexchange_title_best_voted_answer_jsonl/cogsci.stackexchange.com.json.gz": 5101, "stackexchange_title_best_voted_answer_jsonl/computergraphics.stackexchange.com.json.gz": 2306, "stackexchange_title_best_voted_answer_jsonl/conlang.stackexchange.com.json.gz": 334, "stackexchange_title_best_voted_answer_jsonl/cooking.stackexchange.com.json.gz": 22641, "stackexchange_title_best_voted_answer_jsonl/craftcms.stackexchange.com.json.gz": 11236, "stackexchange_title_best_voted_answer_jsonl/crafts.stackexchange.com.json.gz": 1659, "stackexchange_title_best_voted_answer_jsonl/crypto.stackexchange.com.json.gz": 19404, "stackexchange_title_best_voted_answer_jsonl/cs.stackexchange.com.json.gz": 30010, "stackexchange_title_best_voted_answer_jsonl/cseducators.stackexchange.com.json.gz": 902, "stackexchange_title_best_voted_answer_jsonl/cstheory.stackexchange.com.json.gz": 7742, "stackexchange_title_best_voted_answer_jsonl/datascience.stackexchange.com.json.gz": 20503, "stackexchange_title_best_voted_answer_jsonl/dba.stackexchange.com.json.gz": 71449, "stackexchange_title_best_voted_answer_jsonl/devops.stackexchange.com.json.gz": 3462, "stackexchange_title_best_voted_answer_jsonl/diy.stackexchange.com.json.gz": 52896, "stackexchange_title_best_voted_answer_jsonl/drones.stackexchange.com.json.gz": 496, "stackexchange_title_best_voted_answer_jsonl/drupal.stackexchange.com.json.gz": 67817, "stackexchange_title_best_voted_answer_jsonl/dsp.stackexchange.com.json.gz": 17430, "stackexchange_title_best_voted_answer_jsonl/earthscience.stackexchange.com.json.gz": 4396, "stackexchange_title_best_voted_answer_jsonl/ebooks.stackexchange.com.json.gz": 1107, "stackexchange_title_best_voted_answer_jsonl/economics.stackexchange.com.json.gz": 8844, "stackexchange_title_best_voted_answer_jsonl/electronics.stackexchange.com.json.gz": 129494, "stackexchange_title_best_voted_answer_jsonl/ell.stackexchange.com.json.gz": 77892, "stackexchange_title_best_voted_answer_jsonl/emacs.stackexchange.com.json.gz": 16830, "stackexchange_title_best_voted_answer_jsonl/engineering.stackexchange.com.json.gz": 8649, "stackexchange_title_best_voted_answer_jsonl/english.stackexchange.com.json.gz": 100640, "stackexchange_title_best_voted_answer_jsonl/eosio.stackexchange.com.json.gz": 1940, "stackexchange_title_best_voted_answer_jsonl/esperanto.stackexchange.com.json.gz": 1466, "stackexchange_title_best_voted_answer_jsonl/ethereum.stackexchange.com.json.gz": 26124, "stackexchange_title_best_voted_answer_jsonl/expatriates.stackexchange.com.json.gz": 4913, "stackexchange_title_best_voted_answer_jsonl/expressionengine.stackexchange.com.json.gz": 10742, "stackexchange_title_best_voted_answer_jsonl/fitness.stackexchange.com.json.gz": 8297, "stackexchange_title_best_voted_answer_jsonl/freelancing.stackexchange.com.json.gz": 1663, "stackexchange_title_best_voted_answer_jsonl/french.stackexchange.com.json.gz": 10578, "stackexchange_title_best_voted_answer_jsonl/gamedev.stackexchange.com.json.gz": 40154, "stackexchange_title_best_voted_answer_jsonl/gaming.stackexchange.com.json.gz": 82887, "stackexchange_title_best_voted_answer_jsonl/gardening.stackexchange.com.json.gz": 13246, "stackexchange_title_best_voted_answer_jsonl/genealogy.stackexchange.com.json.gz": 2895, "stackexchange_title_best_voted_answer_jsonl/german.stackexchange.com.json.gz": 13733, "stackexchange_title_best_voted_answer_jsonl/gis.stackexchange.com.json.gz": 100254, "stackexchange_title_best_voted_answer_jsonl/graphicdesign.stackexchange.com.json.gz": 28083, "stackexchange_title_best_voted_answer_jsonl/ham.stackexchange.com.json.gz": 3501, "stackexchange_title_best_voted_answer_jsonl/hardwarerecs.stackexchange.com.json.gz": 2050, "stackexchange_title_best_voted_answer_jsonl/health.stackexchange.com.json.gz": 4494, "stackexchange_title_best_voted_answer_jsonl/hermeneutics.stackexchange.com.json.gz": 9516, "stackexchange_title_best_voted_answer_jsonl/hinduism.stackexchange.com.json.gz": 8999, "stackexchange_title_best_voted_answer_jsonl/history.stackexchange.com.json.gz": 10766, "stackexchange_title_best_voted_answer_jsonl/homebrew.stackexchange.com.json.gz": 5608, "stackexchange_title_best_voted_answer_jsonl/hsm.stackexchange.com.json.gz": 2517, "stackexchange_title_best_voted_answer_jsonl/interpersonal.stackexchange.com.json.gz": 3398, "stackexchange_title_best_voted_answer_jsonl/iot.stackexchange.com.json.gz": 1359, "stackexchange_title_best_voted_answer_jsonl/iota.stackexchange.com.json.gz": 775, "stackexchange_title_best_voted_answer_jsonl/islam.stackexchange.com.json.gz": 10052, "stackexchange_title_best_voted_answer_jsonl/italian.stackexchange.com.json.gz": 3101, "stackexchange_title_best_voted_answer_jsonl/ja.stackoverflow.com.json.gz": 17376, "stackexchange_title_best_voted_answer_jsonl/japanese.stackexchange.com.json.gz": 20948, "stackexchange_title_best_voted_answer_jsonl/joomla.stackexchange.com.json.gz": 5887, "stackexchange_title_best_voted_answer_jsonl/judaism.stackexchange.com.json.gz": 26085, "stackexchange_title_best_voted_answer_jsonl/korean.stackexchange.com.json.gz": 1406, "stackexchange_title_best_voted_answer_jsonl/languagelearning.stackexchange.com.json.gz": 948, "stackexchange_title_best_voted_answer_jsonl/latin.stackexchange.com.json.gz": 3969, "stackexchange_title_best_voted_answer_jsonl/law.stackexchange.com.json.gz": 16133, "stackexchange_title_best_voted_answer_jsonl/lifehacks.stackexchange.com.json.gz": 2576, "stackexchange_title_best_voted_answer_jsonl/linguistics.stackexchange.com.json.gz": 6843, "stackexchange_title_best_voted_answer_jsonl/literature.stackexchange.com.json.gz": 3539, "stackexchange_title_best_voted_answer_jsonl/magento.stackexchange.com.json.gz": 79241, "stackexchange_title_best_voted_answer_jsonl/martialarts.stackexchange.com.json.gz": 1737, "stackexchange_title_best_voted_answer_jsonl/materials.stackexchange.com.json.gz": 1101, "stackexchange_title_best_voted_answer_jsonl/matheducators.stackexchange.com.json.gz": 2706, "stackexchange_title_best_voted_answer_jsonl/mathematica.stackexchange.com.json.gz": 59895, "stackexchange_title_best_voted_answer_jsonl/mathoverflow.net.json.gz": 85289, "stackexchange_title_best_voted_answer_jsonl/mechanics.stackexchange.com.json.gz": 18613, "stackexchange_title_best_voted_answer_jsonl/meta.askubuntu.com.json.gz": 4268, "stackexchange_title_best_voted_answer_jsonl/meta.mathoverflow.net.json.gz": 1000, "stackexchange_title_best_voted_answer_jsonl/meta.serverfault.com.json.gz": 1726, "stackexchange_title_best_voted_answer_jsonl/meta.stackexchange.com.json.gz": 60744, "stackexchange_title_best_voted_answer_jsonl/meta.stackoverflow.com.json.gz": 24044, "stackexchange_title_best_voted_answer_jsonl/meta.superuser.com.json.gz": 3629, "stackexchange_title_best_voted_answer_jsonl/moderators.stackexchange.com.json.gz": 504, "stackexchange_title_best_voted_answer_jsonl/money.stackexchange.com.json.gz": 29404, "stackexchange_title_best_voted_answer_jsonl/movies.stackexchange.com.json.gz": 18243, "stackexchange_title_best_voted_answer_jsonl/music.stackexchange.com.json.gz": 19936, "stackexchange_title_best_voted_answer_jsonl/musicfans.stackexchange.com.json.gz": 2431, "stackexchange_title_best_voted_answer_jsonl/mythology.stackexchange.com.json.gz": 1595, "stackexchange_title_best_voted_answer_jsonl/networkengineering.stackexchange.com.json.gz": 12590, "stackexchange_title_best_voted_answer_jsonl/opendata.stackexchange.com.json.gz": 3842, "stackexchange_title_best_voted_answer_jsonl/opensource.stackexchange.com.json.gz": 3221, "stackexchange_title_best_voted_answer_jsonl/or.stackexchange.com.json.gz": 1490, "stackexchange_title_best_voted_answer_jsonl/outdoors.stackexchange.com.json.gz": 5278, "stackexchange_title_best_voted_answer_jsonl/parenting.stackexchange.com.json.gz": 5998, "stackexchange_title_best_voted_answer_jsonl/patents.stackexchange.com.json.gz": 3573, "stackexchange_title_best_voted_answer_jsonl/pets.stackexchange.com.json.gz": 6156, "stackexchange_title_best_voted_answer_jsonl/philosophy.stackexchange.com.json.gz": 13114, "stackexchange_title_best_voted_answer_jsonl/photo.stackexchange.com.json.gz": 23204, "stackexchange_title_best_voted_answer_jsonl/physics.stackexchange.com.json.gz": 141230, "stackexchange_title_best_voted_answer_jsonl/pm.stackexchange.com.json.gz": 5435, "stackexchange_title_best_voted_answer_jsonl/poker.stackexchange.com.json.gz": 1665, "stackexchange_title_best_voted_answer_jsonl/politics.stackexchange.com.json.gz": 11047, "stackexchange_title_best_voted_answer_jsonl/portuguese.stackexchange.com.json.gz": 1964, "stackexchange_title_best_voted_answer_jsonl/pt.stackoverflow.com.json.gz": 103277, "stackexchange_title_best_voted_answer_jsonl/puzzling.stackexchange.com.json.gz": 17448, "stackexchange_title_best_voted_answer_jsonl/quant.stackexchange.com.json.gz": 12933, "stackexchange_title_best_voted_answer_jsonl/quantumcomputing.stackexchange.com.json.gz": 4320, "stackexchange_title_best_voted_answer_jsonl/raspberrypi.stackexchange.com.json.gz": 24143, "stackexchange_title_best_voted_answer_jsonl/retrocomputing.stackexchange.com.json.gz": 3907, "stackexchange_title_best_voted_answer_jsonl/reverseengineering.stackexchange.com.json.gz": 5817, "stackexchange_title_best_voted_answer_jsonl/robotics.stackexchange.com.json.gz": 4648, "stackexchange_title_best_voted_answer_jsonl/rpg.stackexchange.com.json.gz": 40435, "stackexchange_title_best_voted_answer_jsonl/ru.stackoverflow.com.json.gz": 253289, "stackexchange_title_best_voted_answer_jsonl/rus.stackexchange.com.json.gz": 16528, "stackexchange_title_best_voted_answer_jsonl/russian.stackexchange.com.json.gz": 3937, "stackexchange_title_best_voted_answer_jsonl/salesforce.stackexchange.com.json.gz": 87272, "stackexchange_title_best_voted_answer_jsonl/scicomp.stackexchange.com.json.gz": 7036, "stackexchange_title_best_voted_answer_jsonl/scifi.stackexchange.com.json.gz": 54805, "stackexchange_title_best_voted_answer_jsonl/serverfault.com.json.gz": 238507, "stackexchange_title_best_voted_answer_jsonl/sharepoint.stackexchange.com.json.gz": 80420, "stackexchange_title_best_voted_answer_jsonl/sitecore.stackexchange.com.json.gz": 7838, "stackexchange_title_best_voted_answer_jsonl/skeptics.stackexchange.com.json.gz": 8145, "stackexchange_title_best_voted_answer_jsonl/softwareengineering.stackexchange.com.json.gz": 51326, "stackexchange_title_best_voted_answer_jsonl/softwarerecs.stackexchange.com.json.gz": 11761, "stackexchange_title_best_voted_answer_jsonl/sound.stackexchange.com.json.gz": 8303, "stackexchange_title_best_voted_answer_jsonl/space.stackexchange.com.json.gz": 12893, "stackexchange_title_best_voted_answer_jsonl/spanish.stackexchange.com.json.gz": 7675, "stackexchange_title_best_voted_answer_jsonl/sports.stackexchange.com.json.gz": 4707, "stackexchange_title_best_voted_answer_jsonl/sqa.stackexchange.com.json.gz": 9256, "stackexchange_title_best_voted_answer_jsonl/stackapps.com.json.gz": 1518, "stackexchange_title_best_voted_answer_jsonl/stats.stackexchange.com.json.gz": 115679, "stackexchange_title_best_voted_answer_jsonl/stellar.stackexchange.com.json.gz": 1078, "stackexchange_title_best_voted_answer_jsonl/superuser.com.json.gz": 352610, "stackexchange_title_best_voted_answer_jsonl/sustainability.stackexchange.com.json.gz": 1674, "stackexchange_title_best_voted_answer_jsonl/tex.stackexchange.com.json.gz": 171628, "stackexchange_title_best_voted_answer_jsonl/tezos.stackexchange.com.json.gz": 1169, "stackexchange_title_best_voted_answer_jsonl/tor.stackexchange.com.json.gz": 4167, "stackexchange_title_best_voted_answer_jsonl/travel.stackexchange.com.json.gz": 36533, "stackexchange_title_best_voted_answer_jsonl/tridion.stackexchange.com.json.gz": 5907, "stackexchange_title_best_voted_answer_jsonl/ukrainian.stackexchange.com.json.gz": 1767, "stackexchange_title_best_voted_answer_jsonl/unix.stackexchange.com.json.gz": 155414, "stackexchange_title_best_voted_answer_jsonl/ux.stackexchange.com.json.gz": 28901, "stackexchange_title_best_voted_answer_jsonl/vegetarianism.stackexchange.com.json.gz": 585, "stackexchange_title_best_voted_answer_jsonl/vi.stackexchange.com.json.gz": 9000, "stackexchange_title_best_voted_answer_jsonl/webapps.stackexchange.com.json.gz": 24867, "stackexchange_title_best_voted_answer_jsonl/webmasters.stackexchange.com.json.gz": 30370, "stackexchange_title_best_voted_answer_jsonl/windowsphone.stackexchange.com.json.gz": 2807, "stackexchange_title_best_voted_answer_jsonl/woodworking.stackexchange.com.json.gz": 2955, "stackexchange_title_best_voted_answer_jsonl/wordpress.stackexchange.com.json.gz": 83621, "stackexchange_title_best_voted_answer_jsonl/workplace.stackexchange.com.json.gz": 24012, "stackexchange_title_best_voted_answer_jsonl/worldbuilding.stackexchange.com.json.gz": 26210, "stackexchange_title_best_voted_answer_jsonl/writers.stackexchange.com.json.gz": 9867, "stackexchange_title_body_jsonl/academia.stackexchange.com.json.gz": 34331, "stackexchange_title_body_jsonl/android.stackexchange.com.json.gz": 51608, "stackexchange_title_body_jsonl/anime.stackexchange.com.json.gz": 11444, "stackexchange_title_body_jsonl/apple.stackexchange.com.json.gz": 110622, "stackexchange_title_body_jsonl/arduino.stackexchange.com.json.gz": 19553, "stackexchange_title_body_jsonl/askubuntu.com.json.gz": 347925, "stackexchange_title_body_jsonl/astronomy.stackexchange.com.json.gz": 10462, "stackexchange_title_body_jsonl/aviation.stackexchange.com.json.gz": 20139, "stackexchange_title_body_jsonl/bicycles.stackexchange.com.json.gz": 16353, "stackexchange_title_body_jsonl/biology.stackexchange.com.json.gz": 24447, "stackexchange_title_body_jsonl/bitcoin.stackexchange.com.json.gz": 25374, "stackexchange_title_body_jsonl/blender.stackexchange.com.json.gz": 80766, "stackexchange_title_body_jsonl/boardgames.stackexchange.com.json.gz": 12149, "stackexchange_title_body_jsonl/chemistry.stackexchange.com.json.gz": 34506, "stackexchange_title_body_jsonl/christianity.stackexchange.com.json.gz": 12108, "stackexchange_title_body_jsonl/civicrm.stackexchange.com.json.gz": 12543, "stackexchange_title_body_jsonl/codereview.stackexchange.com.json.gz": 45765, "stackexchange_title_body_jsonl/cooking.stackexchange.com.json.gz": 23705, "stackexchange_title_body_jsonl/craftcms.stackexchange.com.json.gz": 12574, "stackexchange_title_body_jsonl/crypto.stackexchange.com.json.gz": 23231, "stackexchange_title_body_jsonl/cs.stackexchange.com.json.gz": 38314, "stackexchange_title_body_jsonl/cstheory.stackexchange.com.json.gz": 10642, "stackexchange_title_body_jsonl/datascience.stackexchange.com.json.gz": 27397, "stackexchange_title_body_jsonl/dba.stackexchange.com.json.gz": 81871, "stackexchange_title_body_jsonl/diy.stackexchange.com.json.gz": 60083, "stackexchange_title_body_jsonl/drupal.stackexchange.com.json.gz": 79717, "stackexchange_title_body_jsonl/dsp.stackexchange.com.json.gz": 21252, "stackexchange_title_body_jsonl/economics.stackexchange.com.json.gz": 11115, "stackexchange_title_body_jsonl/electronics.stackexchange.com.json.gz": 143582, "stackexchange_title_body_jsonl/ell.stackexchange.com.json.gz": 83271, "stackexchange_title_body_jsonl/emacs.stackexchange.com.json.gz": 21055, "stackexchange_title_body_jsonl/engineering.stackexchange.com.json.gz": 10753, "stackexchange_title_body_jsonl/english.stackexchange.com.json.gz": 109522, "stackexchange_title_body_jsonl/ethereum.stackexchange.com.json.gz": 32760, "stackexchange_title_body_jsonl/expressionengine.stackexchange.com.json.gz": 11866, "stackexchange_title_body_jsonl/french.stackexchange.com.json.gz": 10794, "stackexchange_title_body_jsonl/gamedev.stackexchange.com.json.gz": 46485, "stackexchange_title_body_jsonl/gaming.stackexchange.com.json.gz": 88912, "stackexchange_title_body_jsonl/gardening.stackexchange.com.json.gz": 15136, "stackexchange_title_body_jsonl/german.stackexchange.com.json.gz": 13950, "stackexchange_title_body_jsonl/gis.stackexchange.com.json.gz": 131000, "stackexchange_title_body_jsonl/graphicdesign.stackexchange.com.json.gz": 30233, "stackexchange_title_body_jsonl/hinduism.stackexchange.com.json.gz": 13450, "stackexchange_title_body_jsonl/history.stackexchange.com.json.gz": 12021, "stackexchange_title_body_jsonl/islam.stackexchange.com.json.gz": 11853, "stackexchange_title_body_jsonl/japanese.stackexchange.com.json.gz": 22056, "stackexchange_title_body_jsonl/judaism.stackexchange.com.json.gz": 32028, "stackexchange_title_body_jsonl/law.stackexchange.com.json.gz": 17941, "stackexchange_title_body_jsonl/magento.stackexchange.com.json.gz": 99991, "stackexchange_title_body_jsonl/math.stackexchange.com.json.gz": 1338443, "stackexchange_title_body_jsonl/mathematica.stackexchange.com.json.gz": 73131, "stackexchange_title_body_jsonl/mathoverflow.net.json.gz": 120851, "stackexchange_title_body_jsonl/mechanics.stackexchange.com.json.gz": 22868, "stackexchange_title_body_jsonl/meta.stackexchange.com.json.gz": 83510, "stackexchange_title_body_jsonl/meta.stackoverflow.com.json.gz": 36456, "stackexchange_title_body_jsonl/money.stackexchange.com.json.gz": 32021, "stackexchange_title_body_jsonl/movies.stackexchange.com.json.gz": 20181, "stackexchange_title_body_jsonl/music.stackexchange.com.json.gz": 20636, "stackexchange_title_body_jsonl/networkengineering.stackexchange.com.json.gz": 13454, "stackexchange_title_body_jsonl/philosophy.stackexchange.com.json.gz": 14829, "stackexchange_title_body_jsonl/photo.stackexchange.com.json.gz": 23753, "stackexchange_title_body_jsonl/physics.stackexchange.com.json.gz": 173307, "stackexchange_title_body_jsonl/politics.stackexchange.com.json.gz": 11894, "stackexchange_title_body_jsonl/puzzling.stackexchange.com.json.gz": 17851, "stackexchange_title_body_jsonl/quant.stackexchange.com.json.gz": 17261, "stackexchange_title_body_jsonl/raspberrypi.stackexchange.com.json.gz": 30625, "stackexchange_title_body_jsonl/rpg.stackexchange.com.json.gz": 42303, "stackexchange_title_body_jsonl/rus.stackexchange.com.json.gz": 16871, "stackexchange_title_body_jsonl/salesforce.stackexchange.com.json.gz": 105260, "stackexchange_title_body_jsonl/scifi.stackexchange.com.json.gz": 61528, "stackexchange_title_body_jsonl/sharepoint.stackexchange.com.json.gz": 94011, "stackexchange_title_body_jsonl/skeptics.stackexchange.com.json.gz": 10009, "stackexchange_title_body_jsonl/small_stackexchanges.json.gz": 448146, "stackexchange_title_body_jsonl/softwareengineering.stackexchange.com.json.gz": 53942, "stackexchange_title_body_jsonl/softwarerecs.stackexchange.com.json.gz": 20142, "stackexchange_title_body_jsonl/space.stackexchange.com.json.gz": 15142, "stackexchange_title_body_jsonl/stackoverflow.com-Posts.json.gz": 18562443, "stackexchange_title_body_jsonl/stats.stackexchange.com.json.gz": 173466, "stackexchange_title_body_jsonl/superuser.com.json.gz": 435463, "stackexchange_title_body_jsonl/tex.stackexchange.com.json.gz": 202954, "stackexchange_title_body_jsonl/travel.stackexchange.com.json.gz": 41227, "stackexchange_title_body_jsonl/unix.stackexchange.com.json.gz": 185997, "stackexchange_title_body_jsonl/ux.stackexchange.com.json.gz": 29403, "stackexchange_title_body_jsonl/vi.stackexchange.com.json.gz": 10551, "stackexchange_title_body_jsonl/webapps.stackexchange.com.json.gz": 29697, "stackexchange_title_body_jsonl/webmasters.stackexchange.com.json.gz": 34559, "stackexchange_title_body_jsonl/wordpress.stackexchange.com.json.gz": 100474, "stackexchange_title_body_jsonl/workplace.stackexchange.com.json.gz": 24189, "stackexchange_title_body_jsonl/worldbuilding.stackexchange.com.json.gz": 26763, "stackexchange_title_body_jsonl/writers.stackexchange.com.json.gz": 10157, "stackexchange_title_body_small.json.gz": 364000, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/3dprinting.stackexchange.com.json.gz": 109, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/academia.stackexchange.com.json.gz": 2465, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ai.stackexchange.com.json.gz": 130, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/android.stackexchange.com.json.gz": 2830, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/anime.stackexchange.com.json.gz": 802, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/apple.stackexchange.com.json.gz": 6696, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/arduino.stackexchange.com.json.gz": 595, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/askubuntu.com.json.gz": 9975, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/astronomy.stackexchange.com.json.gz": 371, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/aviation.stackexchange.com.json.gz": 903, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/avp.stackexchange.com.json.gz": 152, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/beer.stackexchange.com.json.gz": 57, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/bicycles.stackexchange.com.json.gz": 984, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/bioinformatics.stackexchange.com.json.gz": 39, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/biology.stackexchange.com.json.gz": 832, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/bitcoin.stackexchange.com.json.gz": 1068, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/blender.stackexchange.com.json.gz": 1312, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/boardgames.stackexchange.com.json.gz": 691, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/bricks.stackexchange.com.json.gz": 79, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/buddhism.stackexchange.com.json.gz": 770, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/cardano.stackexchange.com.json.gz": 7, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/chemistry.stackexchange.com.json.gz": 1523, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/chess.stackexchange.com.json.gz": 402, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/chinese.stackexchange.com.json.gz": 611, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/christianity.stackexchange.com.json.gz": 1502, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/civicrm.stackexchange.com.json.gz": 85, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/codegolf.stackexchange.com.json.gz": 333, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/codereview.stackexchange.com.json.gz": 666, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/coffee.stackexchange.com.json.gz": 47, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/cogsci.stackexchange.com.json.gz": 221, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/computergraphics.stackexchange.com.json.gz": 30, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/conlang.stackexchange.com.json.gz": 8, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/cooking.stackexchange.com.json.gz": 2064, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/craftcms.stackexchange.com.json.gz": 26, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/crafts.stackexchange.com.json.gz": 72, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/crypto.stackexchange.com.json.gz": 595, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/cs.stackexchange.com.json.gz": 936, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/cseducators.stackexchange.com.json.gz": 67, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/cstheory.stackexchange.com.json.gz": 314, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/datascience.stackexchange.com.json.gz": 325, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/dba.stackexchange.com.json.gz": 2502, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/devops.stackexchange.com.json.gz": 53, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/diy.stackexchange.com.json.gz": 2037, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/drones.stackexchange.com.json.gz": 6, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/drupal.stackexchange.com.json.gz": 1714, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/dsp.stackexchange.com.json.gz": 387, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/earthscience.stackexchange.com.json.gz": 229, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ebooks.stackexchange.com.json.gz": 54, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/economics.stackexchange.com.json.gz": 441, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/electronics.stackexchange.com.json.gz": 4014, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/elementaryos.stackexchange.com.json.gz": 224, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ell.stackexchange.com.json.gz": 4438, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/emacs.stackexchange.com.json.gz": 188, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/engineering.stackexchange.com.json.gz": 227, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/english.stackexchange.com.json.gz": 13003, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/eosio.stackexchange.com.json.gz": 44, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/esperanto.stackexchange.com.json.gz": 56, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ethereum.stackexchange.com.json.gz": 479, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/expatriates.stackexchange.com.json.gz": 132, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/expressionengine.stackexchange.com.json.gz": 91, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/fitness.stackexchange.com.json.gz": 567, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/freelancing.stackexchange.com.json.gz": 70, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/french.stackexchange.com.json.gz": 632, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/gamedev.stackexchange.com.json.gz": 1598, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/gaming.stackexchange.com.json.gz": 7321, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/gardening.stackexchange.com.json.gz": 210, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/genealogy.stackexchange.com.json.gz": 86, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/german.stackexchange.com.json.gz": 1047, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/gis.stackexchange.com.json.gz": 1843, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/graphicdesign.stackexchange.com.json.gz": 1565, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ham.stackexchange.com.json.gz": 158, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/hardwarerecs.stackexchange.com.json.gz": 58, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/health.stackexchange.com.json.gz": 299, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/hermeneutics.stackexchange.com.json.gz": 1719, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/hinduism.stackexchange.com.json.gz": 343, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/history.stackexchange.com.json.gz": 1099, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/homebrew.stackexchange.com.json.gz": 176, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/hsm.stackexchange.com.json.gz": 70, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/interpersonal.stackexchange.com.json.gz": 469, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/iot.stackexchange.com.json.gz": 10, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/iota.stackexchange.com.json.gz": 31, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/islam.stackexchange.com.json.gz": 2037, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/italian.stackexchange.com.json.gz": 181, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ja.stackoverflow.com.json.gz": 328, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/japanese.stackexchange.com.json.gz": 1124, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/joomla.stackexchange.com.json.gz": 124, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/judaism.stackexchange.com.json.gz": 2216, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/korean.stackexchange.com.json.gz": 28, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/languagelearning.stackexchange.com.json.gz": 42, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/latin.stackexchange.com.json.gz": 55, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/law.stackexchange.com.json.gz": 1297, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/lifehacks.stackexchange.com.json.gz": 316, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/linguistics.stackexchange.com.json.gz": 442, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/literature.stackexchange.com.json.gz": 191, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/magento.stackexchange.com.json.gz": 1849, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/martialarts.stackexchange.com.json.gz": 254, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/materials.stackexchange.com.json.gz": 1, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/matheducators.stackexchange.com.json.gz": 177, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/mathematica.stackexchange.com.json.gz": 262, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/mathoverflow.net.json.gz": 1109, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/mechanics.stackexchange.com.json.gz": 842, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/meta.askubuntu.com.json.gz": 252, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/meta.mathoverflow.net.json.gz": 61, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/meta.serverfault.com.json.gz": 114, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/meta.stackexchange.com.json.gz": 2517, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/meta.stackoverflow.com.json.gz": 2678, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/meta.superuser.com.json.gz": 145, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/moderators.stackexchange.com.json.gz": 23, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/monero.stackexchange.com.json.gz": 26, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/money.stackexchange.com.json.gz": 1905, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/movies.stackexchange.com.json.gz": 1577, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/music.stackexchange.com.json.gz": 1228, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/musicfans.stackexchange.com.json.gz": 78, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/mythology.stackexchange.com.json.gz": 103, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/networkengineering.stackexchange.com.json.gz": 476, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/opendata.stackexchange.com.json.gz": 45, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/opensource.stackexchange.com.json.gz": 123, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/or.stackexchange.com.json.gz": 13, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/outdoors.stackexchange.com.json.gz": 221, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/parenting.stackexchange.com.json.gz": 624, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/patents.stackexchange.com.json.gz": 137, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/pets.stackexchange.com.json.gz": 322, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/philosophy.stackexchange.com.json.gz": 1184, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/photo.stackexchange.com.json.gz": 1432, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/physics.stackexchange.com.json.gz": 8362, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/pm.stackexchange.com.json.gz": 241, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/poker.stackexchange.com.json.gz": 115, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/politics.stackexchange.com.json.gz": 1468, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/portuguese.stackexchange.com.json.gz": 144, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/pt.stackoverflow.com.json.gz": 3718, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/puzzling.stackexchange.com.json.gz": 784, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/quant.stackexchange.com.json.gz": 340, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/quantumcomputing.stackexchange.com.json.gz": 46, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/raspberrypi.stackexchange.com.json.gz": 1011, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/retrocomputing.stackexchange.com.json.gz": 135, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/reverseengineering.stackexchange.com.json.gz": 97, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/robotics.stackexchange.com.json.gz": 110, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/rpg.stackexchange.com.json.gz": 4212, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ru.stackoverflow.com.json.gz": 6305, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/rus.stackexchange.com.json.gz": 514, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/russian.stackexchange.com.json.gz": 353, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/salesforce.stackexchange.com.json.gz": 1781, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/scicomp.stackexchange.com.json.gz": 127, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/scifi.stackexchange.com.json.gz": 5176, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/security.stackexchange.com.json.gz": 3069, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/serverfault.com.json.gz": 7969, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/sharepoint.stackexchange.com.json.gz": 1691, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/sitecore.stackexchange.com.json.gz": 122, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/skeptics.stackexchange.com.json.gz": 670, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/softwareengineering.stackexchange.com.json.gz": 4238, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/softwarerecs.stackexchange.com.json.gz": 348, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/sound.stackexchange.com.json.gz": 365, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/space.stackexchange.com.json.gz": 405, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/spanish.stackexchange.com.json.gz": 366, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/sports.stackexchange.com.json.gz": 455, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/sqa.stackexchange.com.json.gz": 353, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/stackapps.com.json.gz": 15, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/stats.stackexchange.com.json.gz": 2238, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/stellar.stackexchange.com.json.gz": 3, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/superuser.com.json.gz": 17425, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/sustainability.stackexchange.com.json.gz": 152, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/tex.stackexchange.com.json.gz": 1095, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/tezos.stackexchange.com.json.gz": 11, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/tor.stackexchange.com.json.gz": 137, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/travel.stackexchange.com.json.gz": 1317, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/tridion.stackexchange.com.json.gz": 68, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ukrainian.stackexchange.com.json.gz": 87, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/unix.stackexchange.com.json.gz": 6173, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/ux.stackexchange.com.json.gz": 1107, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/vegetarianism.stackexchange.com.json.gz": 35, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/vi.stackexchange.com.json.gz": 95, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/webapps.stackexchange.com.json.gz": 1906, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/webmasters.stackexchange.com.json.gz": 854, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/windowsphone.stackexchange.com.json.gz": 153, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/woodworking.stackexchange.com.json.gz": 93, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/wordpress.stackexchange.com.json.gz": 3046, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/workplace.stackexchange.com.json.gz": 4317, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/worldbuilding.stackexchange.com.json.gz": 2087, "stackexchange_titlebody_best_and_down_voted_answer_jsonl/writers.stackexchange.com.json.gz": 407, "stackexchange_titlebody_best_voted_answer_jsonl/3dprinting.stackexchange.com.json.gz": 3488, "stackexchange_titlebody_best_voted_answer_jsonl/academia.stackexchange.com.json.gz": 32137, "stackexchange_titlebody_best_voted_answer_jsonl/ai.stackexchange.com.json.gz": 5763, "stackexchange_titlebody_best_voted_answer_jsonl/android.stackexchange.com.json.gz": 38077, "stackexchange_titlebody_best_voted_answer_jsonl/anime.stackexchange.com.json.gz": 10131, "stackexchange_titlebody_best_voted_answer_jsonl/apple.stackexchange.com.json.gz": 92487, "stackexchange_titlebody_best_voted_answer_jsonl/arduino.stackexchange.com.json.gz": 16281, "stackexchange_titlebody_best_voted_answer_jsonl/askubuntu.com.json.gz": 267135, "stackexchange_titlebody_best_voted_answer_jsonl/astronomy.stackexchange.com.json.gz": 9086, "stackexchange_titlebody_best_voted_answer_jsonl/aviation.stackexchange.com.json.gz": 18755, "stackexchange_titlebody_best_voted_answer_jsonl/avp.stackexchange.com.json.gz": 6450, "stackexchange_titlebody_best_voted_answer_jsonl/beer.stackexchange.com.json.gz": 1012, "stackexchange_titlebody_best_voted_answer_jsonl/bicycles.stackexchange.com.json.gz": 15708, "stackexchange_titlebody_best_voted_answer_jsonl/bioinformatics.stackexchange.com.json.gz": 3135, "stackexchange_titlebody_best_voted_answer_jsonl/biology.stackexchange.com.json.gz": 19277, "stackexchange_titlebody_best_voted_answer_jsonl/bitcoin.stackexchange.com.json.gz": 22474, "stackexchange_titlebody_best_voted_answer_jsonl/blender.stackexchange.com.json.gz": 54153, "stackexchange_titlebody_best_voted_answer_jsonl/boardgames.stackexchange.com.json.gz": 11805, "stackexchange_titlebody_best_voted_answer_jsonl/bricks.stackexchange.com.json.gz": 3530, "stackexchange_titlebody_best_voted_answer_jsonl/buddhism.stackexchange.com.json.gz": 6787, "stackexchange_titlebody_best_voted_answer_jsonl/cardano.stackexchange.com.json.gz": 248, "stackexchange_titlebody_best_voted_answer_jsonl/chemistry.stackexchange.com.json.gz": 27061, "stackexchange_titlebody_best_voted_answer_jsonl/chess.stackexchange.com.json.gz": 6392, "stackexchange_titlebody_best_voted_answer_jsonl/chinese.stackexchange.com.json.gz": 8646, "stackexchange_titlebody_best_voted_answer_jsonl/christianity.stackexchange.com.json.gz": 11498, "stackexchange_titlebody_best_voted_answer_jsonl/civicrm.stackexchange.com.json.gz": 10648, "stackexchange_titlebody_best_voted_answer_jsonl/codegolf.stackexchange.com.json.gz": 8211, "stackexchange_titlebody_best_voted_answer_jsonl/codereview.stackexchange.com.json.gz": 41748, "stackexchange_titlebody_best_voted_answer_jsonl/coffee.stackexchange.com.json.gz": 1188, "stackexchange_titlebody_best_voted_answer_jsonl/cogsci.stackexchange.com.json.gz": 5101, "stackexchange_titlebody_best_voted_answer_jsonl/computergraphics.stackexchange.com.json.gz": 2306, "stackexchange_titlebody_best_voted_answer_jsonl/conlang.stackexchange.com.json.gz": 334, "stackexchange_titlebody_best_voted_answer_jsonl/cooking.stackexchange.com.json.gz": 22641, "stackexchange_titlebody_best_voted_answer_jsonl/craftcms.stackexchange.com.json.gz": 11236, "stackexchange_titlebody_best_voted_answer_jsonl/crafts.stackexchange.com.json.gz": 1659, "stackexchange_titlebody_best_voted_answer_jsonl/crypto.stackexchange.com.json.gz": 19404, "stackexchange_titlebody_best_voted_answer_jsonl/cs.stackexchange.com.json.gz": 30010, "stackexchange_titlebody_best_voted_answer_jsonl/cseducators.stackexchange.com.json.gz": 902, "stackexchange_titlebody_best_voted_answer_jsonl/cstheory.stackexchange.com.json.gz": 7742, "stackexchange_titlebody_best_voted_answer_jsonl/datascience.stackexchange.com.json.gz": 20503, "stackexchange_titlebody_best_voted_answer_jsonl/dba.stackexchange.com.json.gz": 71449, "stackexchange_titlebody_best_voted_answer_jsonl/devops.stackexchange.com.json.gz": 3462, "stackexchange_titlebody_best_voted_answer_jsonl/diy.stackexchange.com.json.gz": 52896, "stackexchange_titlebody_best_voted_answer_jsonl/drones.stackexchange.com.json.gz": 496, "stackexchange_titlebody_best_voted_answer_jsonl/drupal.stackexchange.com.json.gz": 67817, "stackexchange_titlebody_best_voted_answer_jsonl/dsp.stackexchange.com.json.gz": 17430, "stackexchange_titlebody_best_voted_answer_jsonl/earthscience.stackexchange.com.json.gz": 4396, "stackexchange_titlebody_best_voted_answer_jsonl/ebooks.stackexchange.com.json.gz": 1107, "stackexchange_titlebody_best_voted_answer_jsonl/economics.stackexchange.com.json.gz": 8844, "stackexchange_titlebody_best_voted_answer_jsonl/electronics.stackexchange.com.json.gz": 129494, "stackexchange_titlebody_best_voted_answer_jsonl/elementaryos.stackexchange.com.json.gz": 5917, "stackexchange_titlebody_best_voted_answer_jsonl/ell.stackexchange.com.json.gz": 77892, "stackexchange_titlebody_best_voted_answer_jsonl/emacs.stackexchange.com.json.gz": 16830, "stackexchange_titlebody_best_voted_answer_jsonl/engineering.stackexchange.com.json.gz": 8649, "stackexchange_titlebody_best_voted_answer_jsonl/english.stackexchange.com.json.gz": 100640, "stackexchange_titlebody_best_voted_answer_jsonl/eosio.stackexchange.com.json.gz": 1940, "stackexchange_titlebody_best_voted_answer_jsonl/esperanto.stackexchange.com.json.gz": 1466, "stackexchange_titlebody_best_voted_answer_jsonl/ethereum.stackexchange.com.json.gz": 26124, "stackexchange_titlebody_best_voted_answer_jsonl/expatriates.stackexchange.com.json.gz": 4913, "stackexchange_titlebody_best_voted_answer_jsonl/expressionengine.stackexchange.com.json.gz": 10742, "stackexchange_titlebody_best_voted_answer_jsonl/fitness.stackexchange.com.json.gz": 8297, "stackexchange_titlebody_best_voted_answer_jsonl/freelancing.stackexchange.com.json.gz": 1663, "stackexchange_titlebody_best_voted_answer_jsonl/french.stackexchange.com.json.gz": 10578, "stackexchange_titlebody_best_voted_answer_jsonl/gamedev.stackexchange.com.json.gz": 40154, "stackexchange_titlebody_best_voted_answer_jsonl/gaming.stackexchange.com.json.gz": 82887, "stackexchange_titlebody_best_voted_answer_jsonl/gardening.stackexchange.com.json.gz": 13246, "stackexchange_titlebody_best_voted_answer_jsonl/genealogy.stackexchange.com.json.gz": 2895, "stackexchange_titlebody_best_voted_answer_jsonl/german.stackexchange.com.json.gz": 13733, "stackexchange_titlebody_best_voted_answer_jsonl/gis.stackexchange.com.json.gz": 100254, "stackexchange_titlebody_best_voted_answer_jsonl/graphicdesign.stackexchange.com.json.gz": 28083, "stackexchange_titlebody_best_voted_answer_jsonl/ham.stackexchange.com.json.gz": 3501, "stackexchange_titlebody_best_voted_answer_jsonl/hardwarerecs.stackexchange.com.json.gz": 2050, "stackexchange_titlebody_best_voted_answer_jsonl/health.stackexchange.com.json.gz": 4494, "stackexchange_titlebody_best_voted_answer_jsonl/hermeneutics.stackexchange.com.json.gz": 9516, "stackexchange_titlebody_best_voted_answer_jsonl/hinduism.stackexchange.com.json.gz": 8999, "stackexchange_titlebody_best_voted_answer_jsonl/history.stackexchange.com.json.gz": 10766, "stackexchange_titlebody_best_voted_answer_jsonl/homebrew.stackexchange.com.json.gz": 5608, "stackexchange_titlebody_best_voted_answer_jsonl/hsm.stackexchange.com.json.gz": 2517, "stackexchange_titlebody_best_voted_answer_jsonl/interpersonal.stackexchange.com.json.gz": 3398, "stackexchange_titlebody_best_voted_answer_jsonl/iot.stackexchange.com.json.gz": 1359, "stackexchange_titlebody_best_voted_answer_jsonl/iota.stackexchange.com.json.gz": 775, "stackexchange_titlebody_best_voted_answer_jsonl/islam.stackexchange.com.json.gz": 10052, "stackexchange_titlebody_best_voted_answer_jsonl/italian.stackexchange.com.json.gz": 3101, "stackexchange_titlebody_best_voted_answer_jsonl/ja.stackoverflow.com.json.gz": 17376, "stackexchange_titlebody_best_voted_answer_jsonl/japanese.stackexchange.com.json.gz": 20948, "stackexchange_titlebody_best_voted_answer_jsonl/joomla.stackexchange.com.json.gz": 5887, "stackexchange_titlebody_best_voted_answer_jsonl/judaism.stackexchange.com.json.gz": 26085, "stackexchange_titlebody_best_voted_answer_jsonl/korean.stackexchange.com.json.gz": 1406, "stackexchange_titlebody_best_voted_answer_jsonl/languagelearning.stackexchange.com.json.gz": 948, "stackexchange_titlebody_best_voted_answer_jsonl/latin.stackexchange.com.json.gz": 3969, "stackexchange_titlebody_best_voted_answer_jsonl/law.stackexchange.com.json.gz": 16133, "stackexchange_titlebody_best_voted_answer_jsonl/lifehacks.stackexchange.com.json.gz": 2576, "stackexchange_titlebody_best_voted_answer_jsonl/linguistics.stackexchange.com.json.gz": 6843, "stackexchange_titlebody_best_voted_answer_jsonl/literature.stackexchange.com.json.gz": 3539, "stackexchange_titlebody_best_voted_answer_jsonl/magento.stackexchange.com.json.gz": 79241, "stackexchange_titlebody_best_voted_answer_jsonl/martialarts.stackexchange.com.json.gz": 1737, "stackexchange_titlebody_best_voted_answer_jsonl/materials.stackexchange.com.json.gz": 1101, "stackexchange_titlebody_best_voted_answer_jsonl/matheducators.stackexchange.com.json.gz": 2706, "stackexchange_titlebody_best_voted_answer_jsonl/mathematica.stackexchange.com.json.gz": 59895, "stackexchange_titlebody_best_voted_answer_jsonl/mathoverflow.net.json.gz": 85289, "stackexchange_titlebody_best_voted_answer_jsonl/mechanics.stackexchange.com.json.gz": 18613, "stackexchange_titlebody_best_voted_answer_jsonl/meta.askubuntu.com.json.gz": 4268, "stackexchange_titlebody_best_voted_answer_jsonl/meta.mathoverflow.net.json.gz": 1000, "stackexchange_titlebody_best_voted_answer_jsonl/meta.serverfault.com.json.gz": 1726, "stackexchange_titlebody_best_voted_answer_jsonl/meta.stackexchange.com.json.gz": 60744, "stackexchange_titlebody_best_voted_answer_jsonl/meta.stackoverflow.com.json.gz": 24044, "stackexchange_titlebody_best_voted_answer_jsonl/meta.superuser.com.json.gz": 3629, "stackexchange_titlebody_best_voted_answer_jsonl/moderators.stackexchange.com.json.gz": 504, "stackexchange_titlebody_best_voted_answer_jsonl/money.stackexchange.com.json.gz": 29404, "stackexchange_titlebody_best_voted_answer_jsonl/movies.stackexchange.com.json.gz": 18243, "stackexchange_titlebody_best_voted_answer_jsonl/music.stackexchange.com.json.gz": 19936, "stackexchange_titlebody_best_voted_answer_jsonl/musicfans.stackexchange.com.json.gz": 2431, "stackexchange_titlebody_best_voted_answer_jsonl/mythology.stackexchange.com.json.gz": 1595, "stackexchange_titlebody_best_voted_answer_jsonl/networkengineering.stackexchange.com.json.gz": 12590, "stackexchange_titlebody_best_voted_answer_jsonl/opendata.stackexchange.com.json.gz": 3842, "stackexchange_titlebody_best_voted_answer_jsonl/opensource.stackexchange.com.json.gz": 3221, "stackexchange_titlebody_best_voted_answer_jsonl/or.stackexchange.com.json.gz": 1490, "stackexchange_titlebody_best_voted_answer_jsonl/outdoors.stackexchange.com.json.gz": 5278, "stackexchange_titlebody_best_voted_answer_jsonl/parenting.stackexchange.com.json.gz": 5998, "stackexchange_titlebody_best_voted_answer_jsonl/patents.stackexchange.com.json.gz": 3573, "stackexchange_titlebody_best_voted_answer_jsonl/pets.stackexchange.com.json.gz": 6156, "stackexchange_titlebody_best_voted_answer_jsonl/philosophy.stackexchange.com.json.gz": 13114, "stackexchange_titlebody_best_voted_answer_jsonl/photo.stackexchange.com.json.gz": 23204, "stackexchange_titlebody_best_voted_answer_jsonl/physics.stackexchange.com.json.gz": 141230, "stackexchange_titlebody_best_voted_answer_jsonl/pm.stackexchange.com.json.gz": 5435, "stackexchange_titlebody_best_voted_answer_jsonl/poker.stackexchange.com.json.gz": 1665, "stackexchange_titlebody_best_voted_answer_jsonl/politics.stackexchange.com.json.gz": 11047, "stackexchange_titlebody_best_voted_answer_jsonl/portuguese.stackexchange.com.json.gz": 1964, "stackexchange_titlebody_best_voted_answer_jsonl/pt.stackoverflow.com.json.gz": 103277, "stackexchange_titlebody_best_voted_answer_jsonl/puzzling.stackexchange.com.json.gz": 17448, "stackexchange_titlebody_best_voted_answer_jsonl/quant.stackexchange.com.json.gz": 12933, "stackexchange_titlebody_best_voted_answer_jsonl/quantumcomputing.stackexchange.com.json.gz": 4320, "stackexchange_titlebody_best_voted_answer_jsonl/raspberrypi.stackexchange.com.json.gz": 24143, "stackexchange_titlebody_best_voted_answer_jsonl/retrocomputing.stackexchange.com.json.gz": 3907, "stackexchange_titlebody_best_voted_answer_jsonl/reverseengineering.stackexchange.com.json.gz": 5817, "stackexchange_titlebody_best_voted_answer_jsonl/robotics.stackexchange.com.json.gz": 4648, "stackexchange_titlebody_best_voted_answer_jsonl/rpg.stackexchange.com.json.gz": 40435, "stackexchange_titlebody_best_voted_answer_jsonl/ru.stackoverflow.com.json.gz": 253289, "stackexchange_titlebody_best_voted_answer_jsonl/rus.stackexchange.com.json.gz": 16528, "stackexchange_titlebody_best_voted_answer_jsonl/russian.stackexchange.com.json.gz": 3937, "stackexchange_titlebody_best_voted_answer_jsonl/salesforce.stackexchange.com.json.gz": 87272, "stackexchange_titlebody_best_voted_answer_jsonl/scicomp.stackexchange.com.json.gz": 7036, "stackexchange_titlebody_best_voted_answer_jsonl/scifi.stackexchange.com.json.gz": 54805, "stackexchange_titlebody_best_voted_answer_jsonl/sharepoint.stackexchange.com.json.gz": 80420, "stackexchange_titlebody_best_voted_answer_jsonl/sitecore.stackexchange.com.json.gz": 7838, "stackexchange_titlebody_best_voted_answer_jsonl/skeptics.stackexchange.com.json.gz": 8145, "stackexchange_titlebody_best_voted_answer_jsonl/softwareengineering.stackexchange.com.json.gz": 51326, "stackexchange_titlebody_best_voted_answer_jsonl/softwarerecs.stackexchange.com.json.gz": 11761, "stackexchange_titlebody_best_voted_answer_jsonl/sound.stackexchange.com.json.gz": 8303, "stackexchange_titlebody_best_voted_answer_jsonl/space.stackexchange.com.json.gz": 12893, "stackexchange_titlebody_best_voted_answer_jsonl/spanish.stackexchange.com.json.gz": 7675, "stackexchange_titlebody_best_voted_answer_jsonl/sports.stackexchange.com.json.gz": 4707, "stackexchange_titlebody_best_voted_answer_jsonl/sqa.stackexchange.com.json.gz": 9256, "stackexchange_titlebody_best_voted_answer_jsonl/stackapps.com.json.gz": 1518, "stackexchange_titlebody_best_voted_answer_jsonl/stats.stackexchange.com.json.gz": 115679, "stackexchange_titlebody_best_voted_answer_jsonl/stellar.stackexchange.com.json.gz": 1078, "stackexchange_titlebody_best_voted_answer_jsonl/superuser.com.json.gz": 352610, "stackexchange_titlebody_best_voted_answer_jsonl/sustainability.stackexchange.com.json.gz": 1674, "stackexchange_titlebody_best_voted_answer_jsonl/tex.stackexchange.com.json.gz": 171628, "stackexchange_titlebody_best_voted_answer_jsonl/tezos.stackexchange.com.json.gz": 1169, "stackexchange_titlebody_best_voted_answer_jsonl/tor.stackexchange.com.json.gz": 4167, "stackexchange_titlebody_best_voted_answer_jsonl/travel.stackexchange.com.json.gz": 36533, "stackexchange_titlebody_best_voted_answer_jsonl/tridion.stackexchange.com.json.gz": 5907, "stackexchange_titlebody_best_voted_answer_jsonl/ukrainian.stackexchange.com.json.gz": 1767, "stackexchange_titlebody_best_voted_answer_jsonl/unix.stackexchange.com.json.gz": 155414, "stackexchange_titlebody_best_voted_answer_jsonl/ux.stackexchange.com.json.gz": 28901, "stackexchange_titlebody_best_voted_answer_jsonl/vegetarianism.stackexchange.com.json.gz": 585, "stackexchange_titlebody_best_voted_answer_jsonl/vi.stackexchange.com.json.gz": 9000, "stackexchange_titlebody_best_voted_answer_jsonl/webapps.stackexchange.com.json.gz": 24867, "stackexchange_titlebody_best_voted_answer_jsonl/webmasters.stackexchange.com.json.gz": 30370, "stackexchange_titlebody_best_voted_answer_jsonl/windowsphone.stackexchange.com.json.gz": 2807, "stackexchange_titlebody_best_voted_answer_jsonl/woodworking.stackexchange.com.json.gz": 2955, "stackexchange_titlebody_best_voted_answer_jsonl/wordpress.stackexchange.com.json.gz": 83621, "stackexchange_titlebody_best_voted_answer_jsonl/workplace.stackexchange.com.json.gz": 24012, "stackexchange_titlebody_best_voted_answer_jsonl/worldbuilding.stackexchange.com.json.gz": 26210, "stackexchange_titlebody_best_voted_answer_jsonl/writers.stackexchange.com.json.gz": 9867, "wikihow.json.gz": 128542, "xsum.json.gz": 226711, "yahoo_answers_question_answer.json.gz": 681164, "yahoo_answers_title_answer.json.gz": 1198260, "yahoo_answers_title_question.json.gz": 659896}
|
flash_attn_triton.py
ADDED
@@ -0,0 +1,1112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 MosaicML Examples authors
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
"""Triton implementation of Flash Attention.
|
5 |
+
|
6 |
+
# Copyright (c) 2022, Tri Dao.
|
7 |
+
#
|
8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
9 |
+
# you may not use this file except in compliance with the License.
|
10 |
+
# You may obtain a copy of the License at
|
11 |
+
#
|
12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
13 |
+
#
|
14 |
+
# Unless required by applicable law or agreed to in writing, software
|
15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
17 |
+
# See the License for the specific language governing permissions and
|
18 |
+
# limitations under the License.
|
19 |
+
|
20 |
+
*Experimental* implementation of FlashAttention in Triton.
|
21 |
+
We use the FlashAttention implementation from Phil Tillet a starting point.
|
22 |
+
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
|
23 |
+
|
24 |
+
Changes:
|
25 |
+
- Implement both causal and non-causal attention.
|
26 |
+
- Implement both self-attention and cross-attention.
|
27 |
+
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
|
28 |
+
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
|
29 |
+
- Support attention bias.
|
30 |
+
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
|
31 |
+
- Make the backward for d=128 much faster by reducing register spilling.
|
32 |
+
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
|
33 |
+
small batch size * nheads.
|
34 |
+
|
35 |
+
Caution:
|
36 |
+
- If you plan to use headdim other than 64 and 128, you should test for race conditions
|
37 |
+
(due to the Triton compiler), as done in tests/test_flash_attn.py
|
38 |
+
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
|
39 |
+
for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
|
40 |
+
that there are none left for other head dimensions.
|
41 |
+
Differences between this Triton version and the CUDA version:
|
42 |
+
- Triton version doesn't support dropout.
|
43 |
+
- Triton forward is generally faster than CUDA forward.
|
44 |
+
- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64.
|
45 |
+
It is slightly slower when headdim=128 and batch * nheads is large.
|
46 |
+
- Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
|
47 |
+
"""
|
48 |
+
|
49 |
+
import math
|
50 |
+
|
51 |
+
import torch
|
52 |
+
import triton # type: ignore (reportMissingImports)
|
53 |
+
import triton.language as tl # type: ignore (reportMissingImports)
|
54 |
+
from einops import repeat
|
55 |
+
|
56 |
+
|
57 |
+
@triton.autotune(
|
58 |
+
configs=[
|
59 |
+
triton.Config({
|
60 |
+
'BLOCK_M': 128,
|
61 |
+
'BLOCK_N': 128
|
62 |
+
},
|
63 |
+
num_warps=8,
|
64 |
+
num_stages=1),
|
65 |
+
# This config has a race condition when EVEN_M == False, disabling it for now.
|
66 |
+
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
|
67 |
+
],
|
68 |
+
key=[
|
69 |
+
'CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL',
|
70 |
+
'BLOCK_HEADDIM'
|
71 |
+
])
|
72 |
+
@triton.heuristics({
|
73 |
+
'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0,
|
74 |
+
'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0,
|
75 |
+
'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM'],
|
76 |
+
})
|
77 |
+
@triton.jit
|
78 |
+
def _fwd_kernel(
|
79 |
+
Q,
|
80 |
+
K,
|
81 |
+
V,
|
82 |
+
Bias,
|
83 |
+
Out,
|
84 |
+
Lse,
|
85 |
+
TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
86 |
+
softmax_scale,
|
87 |
+
stride_qb,
|
88 |
+
stride_qh,
|
89 |
+
stride_qm,
|
90 |
+
stride_kb,
|
91 |
+
stride_kh,
|
92 |
+
stride_kn,
|
93 |
+
stride_vb,
|
94 |
+
stride_vh,
|
95 |
+
stride_vn,
|
96 |
+
stride_bb,
|
97 |
+
stride_bh,
|
98 |
+
stride_bm,
|
99 |
+
stride_ob,
|
100 |
+
stride_oh,
|
101 |
+
stride_om,
|
102 |
+
nheads,
|
103 |
+
seqlen_q,
|
104 |
+
seqlen_k,
|
105 |
+
seqlen_q_rounded,
|
106 |
+
headdim,
|
107 |
+
CACHE_KEY_SEQLEN_Q,
|
108 |
+
CACHE_KEY_SEQLEN_K,
|
109 |
+
BIAS_TYPE: tl.constexpr,
|
110 |
+
IS_CAUSAL: tl.constexpr,
|
111 |
+
BLOCK_HEADDIM: tl.constexpr,
|
112 |
+
EVEN_M: tl.constexpr,
|
113 |
+
EVEN_N: tl.constexpr,
|
114 |
+
EVEN_HEADDIM: tl.constexpr,
|
115 |
+
BLOCK_M: tl.constexpr,
|
116 |
+
BLOCK_N: tl.constexpr,
|
117 |
+
):
|
118 |
+
start_m = tl.program_id(0)
|
119 |
+
off_hb = tl.program_id(1)
|
120 |
+
off_b = off_hb // nheads
|
121 |
+
off_h = off_hb % nheads
|
122 |
+
# off_b = tl.program_id(1)
|
123 |
+
# off_h = tl.program_id(2)
|
124 |
+
# off_hb = off_b * nheads + off_h
|
125 |
+
# initialize offsets
|
126 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
127 |
+
offs_n = tl.arange(0, BLOCK_N)
|
128 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
129 |
+
# Initialize pointers to Q, K, V
|
130 |
+
# Adding parenthesis around indexing might use int32 math instead of int64 math?
|
131 |
+
# https://github.com/openai/triton/issues/741
|
132 |
+
# I'm seeing a tiny bit of difference (5-7us)
|
133 |
+
q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (
|
134 |
+
offs_m[:, None] * stride_qm + offs_d[None, :])
|
135 |
+
k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (
|
136 |
+
offs_n[:, None] * stride_kn + offs_d[None, :])
|
137 |
+
v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (
|
138 |
+
offs_n[:, None] * stride_vn + offs_d[None, :])
|
139 |
+
if BIAS_TYPE == 'vector':
|
140 |
+
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
|
141 |
+
elif BIAS_TYPE == 'matrix':
|
142 |
+
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (
|
143 |
+
offs_m[:, None] * stride_bm + offs_n[None, :])
|
144 |
+
else:
|
145 |
+
raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
|
146 |
+
# initialize pointer to m and l
|
147 |
+
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
|
148 |
+
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
|
149 |
+
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
|
150 |
+
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
|
151 |
+
# load q: it will stay in SRAM throughout
|
152 |
+
# [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
|
153 |
+
# tl.load(q_ptrs), we get the wrong output!
|
154 |
+
if EVEN_M & EVEN_N:
|
155 |
+
if EVEN_HEADDIM:
|
156 |
+
q = tl.load(q_ptrs)
|
157 |
+
else:
|
158 |
+
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
159 |
+
else:
|
160 |
+
if EVEN_HEADDIM:
|
161 |
+
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
|
162 |
+
else:
|
163 |
+
q = tl.load(q_ptrs,
|
164 |
+
mask=(offs_m[:, None] < seqlen_q) &
|
165 |
+
(offs_d[None, :] < headdim),
|
166 |
+
other=0.0)
|
167 |
+
# loop over k, v and update accumulator
|
168 |
+
end_n = seqlen_k if not IS_CAUSAL else tl.minimum(
|
169 |
+
(start_m + 1) * BLOCK_M, seqlen_k)
|
170 |
+
for start_n in range(0, end_n, BLOCK_N):
|
171 |
+
start_n = tl.multiple_of(start_n, BLOCK_N)
|
172 |
+
# -- compute qk ----
|
173 |
+
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
|
174 |
+
if EVEN_HEADDIM:
|
175 |
+
k = tl.load(k_ptrs + start_n * stride_kn)
|
176 |
+
else:
|
177 |
+
k = tl.load(k_ptrs + start_n * stride_kn,
|
178 |
+
mask=offs_d[None, :] < headdim,
|
179 |
+
other=0.0)
|
180 |
+
else:
|
181 |
+
if EVEN_HEADDIM:
|
182 |
+
k = tl.load(k_ptrs + start_n * stride_kn,
|
183 |
+
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
184 |
+
other=0.0)
|
185 |
+
else:
|
186 |
+
k = tl.load(k_ptrs + start_n * stride_kn,
|
187 |
+
mask=((start_n + offs_n)[:, None] < seqlen_k) &
|
188 |
+
(offs_d[None, :] < headdim),
|
189 |
+
other=0.0)
|
190 |
+
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
191 |
+
qk += tl.dot(q, k, trans_b=True)
|
192 |
+
# Trying to combine the two masks seem to make the result wrong
|
193 |
+
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
194 |
+
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
|
195 |
+
float('-inf'))
|
196 |
+
if IS_CAUSAL:
|
197 |
+
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0,
|
198 |
+
float('-inf'))
|
199 |
+
if BIAS_TYPE != 'none':
|
200 |
+
if BIAS_TYPE == 'vector':
|
201 |
+
if EVEN_N:
|
202 |
+
bias = tl.load(b_ptrs + start_n).to(tl.float32)
|
203 |
+
else:
|
204 |
+
bias = tl.load(b_ptrs + start_n,
|
205 |
+
mask=(start_n + offs_n) < seqlen_k,
|
206 |
+
other=0.0).to(tl.float32)
|
207 |
+
bias = bias[None, :]
|
208 |
+
elif BIAS_TYPE == 'matrix':
|
209 |
+
if EVEN_M & EVEN_N:
|
210 |
+
bias = tl.load(b_ptrs + start_n).to(tl.float32)
|
211 |
+
else:
|
212 |
+
bias = tl.load(b_ptrs + start_n,
|
213 |
+
mask=(offs_m[:, None] < seqlen_q) &
|
214 |
+
((start_n + offs_n)[None, :] < seqlen_k),
|
215 |
+
other=0.0).to(tl.float32)
|
216 |
+
else:
|
217 |
+
raise ValueError(
|
218 |
+
"BIAS_TYPE must be one of {'vector', 'matrix'}")
|
219 |
+
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
|
220 |
+
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
|
221 |
+
# to multiply with softmax_scale here.
|
222 |
+
qk = qk * softmax_scale + bias
|
223 |
+
m_ij = tl.maximum(tl.max(qk, 1), lse_i)
|
224 |
+
p = tl.exp(qk - m_ij[:, None])
|
225 |
+
else:
|
226 |
+
m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
|
227 |
+
p = tl.exp(qk * softmax_scale - m_ij[:, None])
|
228 |
+
l_ij = tl.sum(p, 1)
|
229 |
+
|
230 |
+
# scale acc_o
|
231 |
+
acc_o_scale = tl.exp(m_i - m_ij)
|
232 |
+
|
233 |
+
# # -- update output accumulator --
|
234 |
+
# BUG: have to store and immediately load
|
235 |
+
tl.store(t_ptrs, acc_o_scale)
|
236 |
+
acc_o_scale = tl.load(t_ptrs)
|
237 |
+
acc_o = acc_o * acc_o_scale[:, None]
|
238 |
+
# update acc_o
|
239 |
+
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
|
240 |
+
if EVEN_HEADDIM:
|
241 |
+
v = tl.load(v_ptrs + start_n * stride_vn)
|
242 |
+
else:
|
243 |
+
v = tl.load(v_ptrs + start_n * stride_vn,
|
244 |
+
mask=offs_d[None, :] < headdim,
|
245 |
+
other=0.0)
|
246 |
+
else:
|
247 |
+
if EVEN_HEADDIM:
|
248 |
+
v = tl.load(v_ptrs + start_n * stride_vn,
|
249 |
+
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
250 |
+
other=0.0)
|
251 |
+
else:
|
252 |
+
v = tl.load(v_ptrs + start_n * stride_vn,
|
253 |
+
mask=((start_n + offs_n)[:, None] < seqlen_k) &
|
254 |
+
(offs_d[None, :] < headdim),
|
255 |
+
other=0.0)
|
256 |
+
p = p.to(v.dtype)
|
257 |
+
acc_o += tl.dot(p, v)
|
258 |
+
|
259 |
+
# -- update statistics
|
260 |
+
m_i = m_ij
|
261 |
+
l_i_new = tl.exp(lse_i - m_ij) + l_ij
|
262 |
+
lse_i = m_ij + tl.log(l_i_new)
|
263 |
+
|
264 |
+
o_scale = tl.exp(m_i - lse_i)
|
265 |
+
# BUG: have to store and immediately load
|
266 |
+
tl.store(t_ptrs, o_scale)
|
267 |
+
o_scale = tl.load(t_ptrs)
|
268 |
+
acc_o = acc_o * o_scale[:, None]
|
269 |
+
# rematerialize offsets to save registers
|
270 |
+
start_m = tl.program_id(0)
|
271 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
272 |
+
# write back l and m
|
273 |
+
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
|
274 |
+
tl.store(lse_ptrs, lse_i)
|
275 |
+
# initialize pointers to output
|
276 |
+
offs_n = tl.arange(0, BLOCK_HEADDIM)
|
277 |
+
out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (
|
278 |
+
offs_m[:, None] * stride_om + offs_n[None, :])
|
279 |
+
if EVEN_M:
|
280 |
+
if EVEN_HEADDIM:
|
281 |
+
tl.store(out_ptrs, acc_o)
|
282 |
+
else:
|
283 |
+
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
|
284 |
+
else:
|
285 |
+
if EVEN_HEADDIM:
|
286 |
+
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
|
287 |
+
else:
|
288 |
+
tl.store(out_ptrs,
|
289 |
+
acc_o,
|
290 |
+
mask=(offs_m[:, None] < seqlen_q) &
|
291 |
+
(offs_d[None, :] < headdim))
|
292 |
+
|
293 |
+
|
294 |
+
@triton.jit
|
295 |
+
def _bwd_preprocess_do_o_dot(
|
296 |
+
Out,
|
297 |
+
DO,
|
298 |
+
Delta,
|
299 |
+
stride_ob,
|
300 |
+
stride_oh,
|
301 |
+
stride_om,
|
302 |
+
stride_dob,
|
303 |
+
stride_doh,
|
304 |
+
stride_dom,
|
305 |
+
nheads,
|
306 |
+
seqlen_q,
|
307 |
+
seqlen_q_rounded,
|
308 |
+
headdim,
|
309 |
+
BLOCK_M: tl.constexpr,
|
310 |
+
BLOCK_HEADDIM: tl.constexpr,
|
311 |
+
):
|
312 |
+
start_m = tl.program_id(0)
|
313 |
+
off_hb = tl.program_id(1)
|
314 |
+
off_b = off_hb // nheads
|
315 |
+
off_h = off_hb % nheads
|
316 |
+
# initialize offsets
|
317 |
+
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
318 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
319 |
+
# load
|
320 |
+
o = tl.load(Out + off_b * stride_ob + off_h * stride_oh +
|
321 |
+
offs_m[:, None] * stride_om + offs_d[None, :],
|
322 |
+
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
323 |
+
other=0.0).to(tl.float32)
|
324 |
+
do = tl.load(DO + off_b * stride_dob + off_h * stride_doh +
|
325 |
+
offs_m[:, None] * stride_dom + offs_d[None, :],
|
326 |
+
mask=(offs_m[:, None] < seqlen_q) &
|
327 |
+
(offs_d[None, :] < headdim),
|
328 |
+
other=0.0).to(tl.float32)
|
329 |
+
delta = tl.sum(o * do, axis=1)
|
330 |
+
# write-back
|
331 |
+
tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
|
332 |
+
|
333 |
+
|
334 |
+
@triton.jit
|
335 |
+
def _bwd_kernel_one_col_block(
|
336 |
+
start_n,
|
337 |
+
Q,
|
338 |
+
K,
|
339 |
+
V,
|
340 |
+
Bias,
|
341 |
+
DO,
|
342 |
+
DQ,
|
343 |
+
DK,
|
344 |
+
DV,
|
345 |
+
LSE,
|
346 |
+
D,
|
347 |
+
softmax_scale,
|
348 |
+
stride_qm,
|
349 |
+
stride_kn,
|
350 |
+
stride_vn,
|
351 |
+
stride_bm,
|
352 |
+
stride_dom,
|
353 |
+
stride_dqm,
|
354 |
+
stride_dkn,
|
355 |
+
stride_dvn,
|
356 |
+
seqlen_q,
|
357 |
+
seqlen_k,
|
358 |
+
headdim,
|
359 |
+
ATOMIC_ADD: tl.constexpr,
|
360 |
+
BIAS_TYPE: tl.constexpr,
|
361 |
+
IS_CAUSAL: tl.constexpr,
|
362 |
+
BLOCK_HEADDIM: tl.constexpr,
|
363 |
+
EVEN_M: tl.constexpr,
|
364 |
+
EVEN_N: tl.constexpr,
|
365 |
+
EVEN_HEADDIM: tl.constexpr,
|
366 |
+
BLOCK_M: tl.constexpr,
|
367 |
+
BLOCK_N: tl.constexpr,
|
368 |
+
):
|
369 |
+
# We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
|
370 |
+
begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
|
371 |
+
# initialize row/col offsets
|
372 |
+
offs_qm = begin_m + tl.arange(0, BLOCK_M)
|
373 |
+
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
374 |
+
offs_m = tl.arange(0, BLOCK_M)
|
375 |
+
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
376 |
+
# initialize pointers to value-like data
|
377 |
+
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
|
378 |
+
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
|
379 |
+
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
|
380 |
+
do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
|
381 |
+
dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
|
382 |
+
if BIAS_TYPE == 'vector':
|
383 |
+
b_ptrs = Bias + offs_n
|
384 |
+
elif BIAS_TYPE == 'matrix':
|
385 |
+
b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
|
386 |
+
else:
|
387 |
+
raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
|
388 |
+
# initialize dv and dk
|
389 |
+
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
390 |
+
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
391 |
+
# k and v stay in SRAM throughout
|
392 |
+
# [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
|
393 |
+
# if we just call tl.load(k_ptrs), we get the wrong output!
|
394 |
+
if EVEN_N & EVEN_M:
|
395 |
+
if EVEN_HEADDIM:
|
396 |
+
k = tl.load(k_ptrs)
|
397 |
+
v = tl.load(v_ptrs)
|
398 |
+
else:
|
399 |
+
k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
400 |
+
v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
401 |
+
else:
|
402 |
+
if EVEN_HEADDIM:
|
403 |
+
k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
|
404 |
+
v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
|
405 |
+
else:
|
406 |
+
k = tl.load(k_ptrs,
|
407 |
+
mask=(offs_n[:, None] < seqlen_k) &
|
408 |
+
(offs_d[None, :] < headdim),
|
409 |
+
other=0.0)
|
410 |
+
v = tl.load(v_ptrs,
|
411 |
+
mask=(offs_n[:, None] < seqlen_k) &
|
412 |
+
(offs_d[None, :] < headdim),
|
413 |
+
other=0.0)
|
414 |
+
# loop over rows
|
415 |
+
num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
|
416 |
+
for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
|
417 |
+
start_m = tl.multiple_of(start_m, BLOCK_M)
|
418 |
+
offs_m_curr = start_m + offs_m
|
419 |
+
# load q, k, v, do on-chip
|
420 |
+
# Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
|
421 |
+
if EVEN_M & EVEN_HEADDIM:
|
422 |
+
q = tl.load(q_ptrs)
|
423 |
+
else:
|
424 |
+
if EVEN_HEADDIM:
|
425 |
+
q = tl.load(q_ptrs,
|
426 |
+
mask=offs_m_curr[:, None] < seqlen_q,
|
427 |
+
other=0.0)
|
428 |
+
else:
|
429 |
+
q = tl.load(q_ptrs,
|
430 |
+
mask=(offs_m_curr[:, None] < seqlen_q) &
|
431 |
+
(offs_d[None, :] < headdim),
|
432 |
+
other=0.0)
|
433 |
+
# recompute p = softmax(qk, dim=-1).T
|
434 |
+
qk = tl.dot(q, k, trans_b=True)
|
435 |
+
# Trying to combine the two masks seem to make the result wrong
|
436 |
+
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
437 |
+
qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
|
438 |
+
if IS_CAUSAL:
|
439 |
+
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk,
|
440 |
+
float('-inf'))
|
441 |
+
if BIAS_TYPE != 'none':
|
442 |
+
if BIAS_TYPE == 'vector':
|
443 |
+
if EVEN_N:
|
444 |
+
bias = tl.load(b_ptrs).to(tl.float32)
|
445 |
+
else:
|
446 |
+
bias = tl.load(b_ptrs, mask=offs_n < seqlen_k,
|
447 |
+
other=0.0).to(tl.float32)
|
448 |
+
bias = bias[None, :]
|
449 |
+
elif BIAS_TYPE == 'matrix':
|
450 |
+
if EVEN_M & EVEN_N:
|
451 |
+
bias = tl.load(b_ptrs).to(tl.float32)
|
452 |
+
else:
|
453 |
+
bias = tl.load(b_ptrs,
|
454 |
+
mask=(offs_m_curr[:, None] < seqlen_q) &
|
455 |
+
(offs_n[None, :] < seqlen_k),
|
456 |
+
other=0.0).to(tl.float32)
|
457 |
+
else:
|
458 |
+
raise ValueError(
|
459 |
+
"BIAS_TYPE must be one of {'vector', 'matrix'}")
|
460 |
+
qk = qk * softmax_scale + bias
|
461 |
+
# There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
|
462 |
+
# Also wrong for headdim=64.
|
463 |
+
if not (EVEN_M & EVEN_HEADDIM):
|
464 |
+
tl.debug_barrier()
|
465 |
+
lse_i = tl.load(LSE + offs_m_curr)
|
466 |
+
if BIAS_TYPE == 'none':
|
467 |
+
p = tl.exp(qk * softmax_scale - lse_i[:, None])
|
468 |
+
else:
|
469 |
+
p = tl.exp(qk - lse_i[:, None])
|
470 |
+
# compute dv
|
471 |
+
# [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
|
472 |
+
# do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
|
473 |
+
# in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
|
474 |
+
# the output is correct.
|
475 |
+
if EVEN_M & EVEN_HEADDIM:
|
476 |
+
do = tl.load(do_ptrs)
|
477 |
+
else:
|
478 |
+
# [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
|
479 |
+
do = tl.load(do_ptrs,
|
480 |
+
mask=(offs_m_curr[:, None] < seqlen_q) &
|
481 |
+
(offs_d[None, :] < headdim),
|
482 |
+
other=0.0)
|
483 |
+
# if EVEN_M:
|
484 |
+
# if EVEN_HEADDIM:
|
485 |
+
# do = tl.load(do_ptrs)
|
486 |
+
# else:
|
487 |
+
# do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
488 |
+
# else:
|
489 |
+
# if EVEN_HEADDIM:
|
490 |
+
# do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
|
491 |
+
# else:
|
492 |
+
# do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
|
493 |
+
# & (offs_d[None, :] < headdim), other=0.0)
|
494 |
+
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
|
495 |
+
# compute dp = dot(v, do)
|
496 |
+
# There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
|
497 |
+
# Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
|
498 |
+
# Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
|
499 |
+
if not (EVEN_M & EVEN_HEADDIM):
|
500 |
+
tl.debug_barrier()
|
501 |
+
dp = tl.dot(do, v, trans_b=True)
|
502 |
+
# There's a race condition for headdim=48
|
503 |
+
if not EVEN_HEADDIM:
|
504 |
+
tl.debug_barrier()
|
505 |
+
# compute ds = p * (dp - delta[:, None])
|
506 |
+
# Putting the subtraction after the dp matmul (instead of before) is slightly faster
|
507 |
+
Di = tl.load(D + offs_m_curr)
|
508 |
+
# Converting ds to q.dtype here reduces register pressure and makes it much faster
|
509 |
+
# for BLOCK_HEADDIM=128
|
510 |
+
ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
|
511 |
+
# compute dk = dot(ds.T, q)
|
512 |
+
dk += tl.dot(ds, q, trans_a=True)
|
513 |
+
# compute dq
|
514 |
+
if not ATOMIC_ADD:
|
515 |
+
if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
|
516 |
+
dq = tl.load(dq_ptrs, eviction_policy='evict_last')
|
517 |
+
dq += tl.dot(ds, k)
|
518 |
+
tl.store(dq_ptrs, dq, eviction_policy='evict_last')
|
519 |
+
else:
|
520 |
+
if EVEN_HEADDIM:
|
521 |
+
dq = tl.load(dq_ptrs,
|
522 |
+
mask=offs_m_curr[:, None] < seqlen_q,
|
523 |
+
other=0.0,
|
524 |
+
eviction_policy='evict_last')
|
525 |
+
dq += tl.dot(ds, k)
|
526 |
+
tl.store(dq_ptrs,
|
527 |
+
dq,
|
528 |
+
mask=offs_m_curr[:, None] < seqlen_q,
|
529 |
+
eviction_policy='evict_last')
|
530 |
+
else:
|
531 |
+
dq = tl.load(dq_ptrs,
|
532 |
+
mask=(offs_m_curr[:, None] < seqlen_q) &
|
533 |
+
(offs_d[None, :] < headdim),
|
534 |
+
other=0.0,
|
535 |
+
eviction_policy='evict_last')
|
536 |
+
dq += tl.dot(ds, k)
|
537 |
+
tl.store(dq_ptrs,
|
538 |
+
dq,
|
539 |
+
mask=(offs_m_curr[:, None] < seqlen_q) &
|
540 |
+
(offs_d[None, :] < headdim),
|
541 |
+
eviction_policy='evict_last')
|
542 |
+
else: # If we're parallelizing across the seqlen_k dimension
|
543 |
+
dq = tl.dot(ds, k)
|
544 |
+
if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
|
545 |
+
tl.atomic_add(dq_ptrs, dq)
|
546 |
+
else:
|
547 |
+
if EVEN_HEADDIM:
|
548 |
+
tl.atomic_add(dq_ptrs,
|
549 |
+
dq,
|
550 |
+
mask=offs_m_curr[:, None] < seqlen_q)
|
551 |
+
else:
|
552 |
+
tl.atomic_add(dq_ptrs,
|
553 |
+
dq,
|
554 |
+
mask=(offs_m_curr[:, None] < seqlen_q) &
|
555 |
+
(offs_d[None, :] < headdim))
|
556 |
+
# increment pointers
|
557 |
+
dq_ptrs += BLOCK_M * stride_dqm
|
558 |
+
q_ptrs += BLOCK_M * stride_qm
|
559 |
+
do_ptrs += BLOCK_M * stride_dom
|
560 |
+
if BIAS_TYPE == 'matrix':
|
561 |
+
b_ptrs += BLOCK_M * stride_bm
|
562 |
+
# write-back
|
563 |
+
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
|
564 |
+
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
|
565 |
+
# [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
|
566 |
+
# if we just call tl.store(dv_ptrs), there's a race condition
|
567 |
+
if EVEN_N & EVEN_M:
|
568 |
+
if EVEN_HEADDIM:
|
569 |
+
tl.store(dv_ptrs, dv)
|
570 |
+
tl.store(dk_ptrs, dk)
|
571 |
+
else:
|
572 |
+
tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
|
573 |
+
tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
|
574 |
+
else:
|
575 |
+
if EVEN_HEADDIM:
|
576 |
+
tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
|
577 |
+
tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
|
578 |
+
else:
|
579 |
+
tl.store(dv_ptrs,
|
580 |
+
dv,
|
581 |
+
mask=(offs_n[:, None] < seqlen_k) &
|
582 |
+
(offs_d[None, :] < headdim))
|
583 |
+
tl.store(dk_ptrs,
|
584 |
+
dk,
|
585 |
+
mask=(offs_n[:, None] < seqlen_k) &
|
586 |
+
(offs_d[None, :] < headdim))
|
587 |
+
|
588 |
+
|
589 |
+
def init_to_zero(name):
|
590 |
+
return lambda nargs: nargs[name].zero_()
|
591 |
+
|
592 |
+
|
593 |
+
@triton.autotune(
|
594 |
+
configs=[
|
595 |
+
triton.Config(
|
596 |
+
{
|
597 |
+
'BLOCK_M': 128,
|
598 |
+
'BLOCK_N': 128,
|
599 |
+
'SEQUENCE_PARALLEL': False
|
600 |
+
},
|
601 |
+
num_warps=8,
|
602 |
+
num_stages=1,
|
603 |
+
pre_hook=init_to_zero('DQ')),
|
604 |
+
triton.Config(
|
605 |
+
{
|
606 |
+
'BLOCK_M': 128,
|
607 |
+
'BLOCK_N': 128,
|
608 |
+
'SEQUENCE_PARALLEL': True
|
609 |
+
},
|
610 |
+
num_warps=8,
|
611 |
+
num_stages=1,
|
612 |
+
pre_hook=init_to_zero('DQ')),
|
613 |
+
# Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
|
614 |
+
# # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
|
615 |
+
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
|
616 |
+
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
|
617 |
+
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
|
618 |
+
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
|
619 |
+
],
|
620 |
+
key=[
|
621 |
+
'CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL',
|
622 |
+
'BLOCK_HEADDIM'
|
623 |
+
],
|
624 |
+
)
|
625 |
+
@triton.heuristics({
|
626 |
+
'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0,
|
627 |
+
'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0,
|
628 |
+
'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM'],
|
629 |
+
})
|
630 |
+
@triton.jit
|
631 |
+
def _bwd_kernel(
|
632 |
+
Q,
|
633 |
+
K,
|
634 |
+
V,
|
635 |
+
Bias,
|
636 |
+
DO,
|
637 |
+
DQ,
|
638 |
+
DK,
|
639 |
+
DV,
|
640 |
+
LSE,
|
641 |
+
D,
|
642 |
+
softmax_scale,
|
643 |
+
stride_qb,
|
644 |
+
stride_qh,
|
645 |
+
stride_qm,
|
646 |
+
stride_kb,
|
647 |
+
stride_kh,
|
648 |
+
stride_kn,
|
649 |
+
stride_vb,
|
650 |
+
stride_vh,
|
651 |
+
stride_vn,
|
652 |
+
stride_bb,
|
653 |
+
stride_bh,
|
654 |
+
stride_bm,
|
655 |
+
stride_dob,
|
656 |
+
stride_doh,
|
657 |
+
stride_dom,
|
658 |
+
stride_dqb,
|
659 |
+
stride_dqh,
|
660 |
+
stride_dqm,
|
661 |
+
stride_dkb,
|
662 |
+
stride_dkh,
|
663 |
+
stride_dkn,
|
664 |
+
stride_dvb,
|
665 |
+
stride_dvh,
|
666 |
+
stride_dvn,
|
667 |
+
nheads,
|
668 |
+
seqlen_q,
|
669 |
+
seqlen_k,
|
670 |
+
seqlen_q_rounded,
|
671 |
+
headdim,
|
672 |
+
CACHE_KEY_SEQLEN_Q,
|
673 |
+
CACHE_KEY_SEQLEN_K,
|
674 |
+
BIAS_TYPE: tl.constexpr,
|
675 |
+
IS_CAUSAL: tl.constexpr,
|
676 |
+
BLOCK_HEADDIM: tl.constexpr,
|
677 |
+
SEQUENCE_PARALLEL: tl.constexpr,
|
678 |
+
EVEN_M: tl.constexpr,
|
679 |
+
EVEN_N: tl.constexpr,
|
680 |
+
EVEN_HEADDIM: tl.constexpr,
|
681 |
+
BLOCK_M: tl.constexpr,
|
682 |
+
BLOCK_N: tl.constexpr,
|
683 |
+
):
|
684 |
+
off_hb = tl.program_id(1)
|
685 |
+
off_b = off_hb // nheads
|
686 |
+
off_h = off_hb % nheads
|
687 |
+
# offset pointers for batch/head
|
688 |
+
Q += off_b * stride_qb + off_h * stride_qh
|
689 |
+
K += off_b * stride_kb + off_h * stride_kh
|
690 |
+
V += off_b * stride_vb + off_h * stride_vh
|
691 |
+
DO += off_b * stride_dob + off_h * stride_doh
|
692 |
+
DQ += off_b * stride_dqb + off_h * stride_dqh
|
693 |
+
DK += off_b * stride_dkb + off_h * stride_dkh
|
694 |
+
DV += off_b * stride_dvb + off_h * stride_dvh
|
695 |
+
if BIAS_TYPE != 'none':
|
696 |
+
Bias += off_b * stride_bb + off_h * stride_bh
|
697 |
+
# pointer to row-wise quantities in value-like data
|
698 |
+
D += off_hb * seqlen_q_rounded
|
699 |
+
LSE += off_hb * seqlen_q_rounded
|
700 |
+
if not SEQUENCE_PARALLEL:
|
701 |
+
num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
|
702 |
+
for start_n in range(0, num_block_n):
|
703 |
+
_bwd_kernel_one_col_block(start_n,
|
704 |
+
Q,
|
705 |
+
K,
|
706 |
+
V,
|
707 |
+
Bias,
|
708 |
+
DO,
|
709 |
+
DQ,
|
710 |
+
DK,
|
711 |
+
DV,
|
712 |
+
LSE,
|
713 |
+
D,
|
714 |
+
softmax_scale,
|
715 |
+
stride_qm,
|
716 |
+
stride_kn,
|
717 |
+
stride_vn,
|
718 |
+
stride_bm,
|
719 |
+
stride_dom,
|
720 |
+
stride_dqm,
|
721 |
+
stride_dkn,
|
722 |
+
stride_dvn,
|
723 |
+
seqlen_q,
|
724 |
+
seqlen_k,
|
725 |
+
headdim,
|
726 |
+
ATOMIC_ADD=False,
|
727 |
+
BIAS_TYPE=BIAS_TYPE,
|
728 |
+
IS_CAUSAL=IS_CAUSAL,
|
729 |
+
BLOCK_HEADDIM=BLOCK_HEADDIM,
|
730 |
+
EVEN_M=EVEN_M,
|
731 |
+
EVEN_N=EVEN_N,
|
732 |
+
EVEN_HEADDIM=EVEN_HEADDIM,
|
733 |
+
BLOCK_M=BLOCK_M,
|
734 |
+
BLOCK_N=BLOCK_N)
|
735 |
+
else:
|
736 |
+
start_n = tl.program_id(0)
|
737 |
+
_bwd_kernel_one_col_block(start_n,
|
738 |
+
Q,
|
739 |
+
K,
|
740 |
+
V,
|
741 |
+
Bias,
|
742 |
+
DO,
|
743 |
+
DQ,
|
744 |
+
DK,
|
745 |
+
DV,
|
746 |
+
LSE,
|
747 |
+
D,
|
748 |
+
softmax_scale,
|
749 |
+
stride_qm,
|
750 |
+
stride_kn,
|
751 |
+
stride_vn,
|
752 |
+
stride_bm,
|
753 |
+
stride_dom,
|
754 |
+
stride_dqm,
|
755 |
+
stride_dkn,
|
756 |
+
stride_dvn,
|
757 |
+
seqlen_q,
|
758 |
+
seqlen_k,
|
759 |
+
headdim,
|
760 |
+
ATOMIC_ADD=True,
|
761 |
+
BIAS_TYPE=BIAS_TYPE,
|
762 |
+
IS_CAUSAL=IS_CAUSAL,
|
763 |
+
BLOCK_HEADDIM=BLOCK_HEADDIM,
|
764 |
+
EVEN_M=EVEN_M,
|
765 |
+
EVEN_N=EVEN_N,
|
766 |
+
EVEN_HEADDIM=EVEN_HEADDIM,
|
767 |
+
BLOCK_M=BLOCK_M,
|
768 |
+
BLOCK_N=BLOCK_N)
|
769 |
+
|
770 |
+
|
771 |
+
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
|
772 |
+
# shape constraints
|
773 |
+
batch, seqlen_q, nheads, d = q.shape
|
774 |
+
_, seqlen_k, _, _ = k.shape
|
775 |
+
assert k.shape == (batch, seqlen_k, nheads, d)
|
776 |
+
assert v.shape == (batch, seqlen_k, nheads, d)
|
777 |
+
assert d <= 128, 'FlashAttention only support head dimensions up to 128'
|
778 |
+
assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
|
779 |
+
assert q.dtype in [torch.float16,
|
780 |
+
torch.bfloat16], 'Only support fp16 and bf16'
|
781 |
+
assert q.is_cuda and k.is_cuda and v.is_cuda
|
782 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
|
783 |
+
|
784 |
+
has_bias = bias is not None
|
785 |
+
bias_type = 'none'
|
786 |
+
if has_bias:
|
787 |
+
assert bias.dtype in [q.dtype, torch.float]
|
788 |
+
assert bias.is_cuda
|
789 |
+
assert bias.dim() == 4
|
790 |
+
if bias.stride(-1) != 1:
|
791 |
+
bias = bias.contiguous()
|
792 |
+
if bias.shape[2:] == (1, seqlen_k):
|
793 |
+
bias_type = 'vector'
|
794 |
+
elif bias.shape[2:] == (seqlen_q, seqlen_k):
|
795 |
+
bias_type = 'matrix'
|
796 |
+
else:
|
797 |
+
raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
|
798 |
+
' or (seqlen_q, seqlen_k)')
|
799 |
+
if bias.shape[:2] == (1, nheads):
|
800 |
+
bias = repeat(bias, '1 h ... -> b h ...', b=batch)
|
801 |
+
elif bias.shape[:2] == (batch, 1):
|
802 |
+
bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
|
803 |
+
elif bias.shape[:2] == (1, 1):
|
804 |
+
bias = repeat(bias, '1 h ... -> b h ...', b=batch)
|
805 |
+
bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
|
806 |
+
assert bias.shape[:2] == (
|
807 |
+
batch, nheads
|
808 |
+
), f'First 2 dimensions of bias must be broadcastible to (batch, nheads) = ({batch, nheads}). Bias has shape: {bias.shape}'
|
809 |
+
assert bias is not None # for type checking
|
810 |
+
bias_strides = (bias.stride(0), bias.stride(1),
|
811 |
+
bias.stride(2)) if has_bias else (0, 0, 0)
|
812 |
+
|
813 |
+
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
|
814 |
+
lse = torch.empty((batch, nheads, seqlen_q_rounded),
|
815 |
+
device=q.device,
|
816 |
+
dtype=torch.float32)
|
817 |
+
tmp = torch.empty((batch, nheads, seqlen_q_rounded),
|
818 |
+
device=q.device,
|
819 |
+
dtype=torch.float32)
|
820 |
+
o = torch.empty_like(q)
|
821 |
+
|
822 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
|
823 |
+
# BLOCK = 128
|
824 |
+
# num_warps = 4 if d <= 64 else 8
|
825 |
+
grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
|
826 |
+
_fwd_kernel[grid]( # type: ignore
|
827 |
+
q,
|
828 |
+
k,
|
829 |
+
v,
|
830 |
+
bias,
|
831 |
+
o,
|
832 |
+
lse,
|
833 |
+
tmp,
|
834 |
+
softmax_scale,
|
835 |
+
q.stride(0),
|
836 |
+
q.stride(2),
|
837 |
+
q.stride(1),
|
838 |
+
k.stride(0),
|
839 |
+
k.stride(2),
|
840 |
+
k.stride(1),
|
841 |
+
v.stride(0),
|
842 |
+
v.stride(2),
|
843 |
+
v.stride(1),
|
844 |
+
*bias_strides,
|
845 |
+
o.stride(0),
|
846 |
+
o.stride(2),
|
847 |
+
o.stride(1),
|
848 |
+
nheads,
|
849 |
+
seqlen_q,
|
850 |
+
seqlen_k,
|
851 |
+
seqlen_q_rounded,
|
852 |
+
d,
|
853 |
+
seqlen_q // 32,
|
854 |
+
seqlen_k // 32, # key for triton cache (limit number of compilations)
|
855 |
+
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
|
856 |
+
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
|
857 |
+
bias_type,
|
858 |
+
causal,
|
859 |
+
BLOCK_HEADDIM,
|
860 |
+
# BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
861 |
+
# num_warps=num_warps,
|
862 |
+
# num_stages=1,
|
863 |
+
)
|
864 |
+
return o, lse, softmax_scale # softmax_scale could have been updated
|
865 |
+
|
866 |
+
|
867 |
+
def _flash_attn_backward(do,
|
868 |
+
q,
|
869 |
+
k,
|
870 |
+
v,
|
871 |
+
o,
|
872 |
+
lse,
|
873 |
+
dq,
|
874 |
+
dk,
|
875 |
+
dv,
|
876 |
+
bias=None,
|
877 |
+
causal=False,
|
878 |
+
softmax_scale=None):
|
879 |
+
# Make sure that the last dimension is contiguous
|
880 |
+
if do.stride(-1) != 1:
|
881 |
+
do = do.contiguous()
|
882 |
+
batch, seqlen_q, nheads, d = q.shape
|
883 |
+
_, seqlen_k, _, _ = k.shape
|
884 |
+
# assert d in {16, 32, 64, 128}
|
885 |
+
assert d <= 128
|
886 |
+
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
|
887 |
+
assert lse.shape == (batch, nheads, seqlen_q_rounded)
|
888 |
+
assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
|
889 |
+
assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
|
890 |
+
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
|
891 |
+
# dq_accum = torch.zeros_like(q, dtype=torch.float32)
|
892 |
+
dq_accum = torch.empty_like(q, dtype=torch.float32)
|
893 |
+
delta = torch.empty_like(lse)
|
894 |
+
# delta = torch.zeros_like(lse)
|
895 |
+
|
896 |
+
BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
|
897 |
+
grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
|
898 |
+
_bwd_preprocess_do_o_dot[grid]( # type: ignore
|
899 |
+
o,
|
900 |
+
do,
|
901 |
+
delta,
|
902 |
+
o.stride(0),
|
903 |
+
o.stride(2),
|
904 |
+
o.stride(1),
|
905 |
+
do.stride(0),
|
906 |
+
do.stride(2),
|
907 |
+
do.stride(1),
|
908 |
+
nheads,
|
909 |
+
seqlen_q,
|
910 |
+
seqlen_q_rounded,
|
911 |
+
d,
|
912 |
+
BLOCK_M=128,
|
913 |
+
BLOCK_HEADDIM=BLOCK_HEADDIM,
|
914 |
+
)
|
915 |
+
|
916 |
+
has_bias = bias is not None
|
917 |
+
bias_type = 'none'
|
918 |
+
if has_bias:
|
919 |
+
assert bias.dtype in [q.dtype, torch.float]
|
920 |
+
assert bias.is_cuda
|
921 |
+
assert bias.dim() == 4
|
922 |
+
assert bias.stride(-1) == 1
|
923 |
+
if bias.shape[2:] == (1, seqlen_k):
|
924 |
+
bias_type = 'vector'
|
925 |
+
elif bias.shape[2:] == (seqlen_q, seqlen_k):
|
926 |
+
bias_type = 'matrix'
|
927 |
+
else:
|
928 |
+
raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
|
929 |
+
' or (seqlen_q, seqlen_k)')
|
930 |
+
if bias.shape[:2] == (1, nheads):
|
931 |
+
bias = repeat(bias, '1 h ... -> b h ...', b=batch)
|
932 |
+
elif bias.shape[:2] == (batch, 1):
|
933 |
+
bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
|
934 |
+
elif bias.shape[:2] == (1, 1):
|
935 |
+
bias = repeat(bias, '1 h ... -> b h ...', b=batch)
|
936 |
+
bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
|
937 |
+
assert bias.shape[:2] == (
|
938 |
+
batch, nheads
|
939 |
+
), f'First 2 dimensions of bias must be broadcastible to (batch, nheads) = ({batch, nheads}). Bias has shape: {bias.shape}'
|
940 |
+
assert bias is not None # type checking
|
941 |
+
bias_strides = (bias.stride(0), bias.stride(1),
|
942 |
+
bias.stride(2)) if has_bias else (0, 0, 0)
|
943 |
+
|
944 |
+
# BLOCK_M = 128
|
945 |
+
# BLOCK_N = 64
|
946 |
+
# num_warps = 4
|
947 |
+
grid = lambda META: (triton.cdiv(seqlen_k, META['BLOCK_N'])
|
948 |
+
if META['SEQUENCE_PARALLEL'] else 1, batch * nheads)
|
949 |
+
_bwd_kernel[grid]( # type: ignore
|
950 |
+
q,
|
951 |
+
k,
|
952 |
+
v,
|
953 |
+
bias,
|
954 |
+
do,
|
955 |
+
dq_accum,
|
956 |
+
dk,
|
957 |
+
dv,
|
958 |
+
lse,
|
959 |
+
delta,
|
960 |
+
softmax_scale,
|
961 |
+
q.stride(0),
|
962 |
+
q.stride(2),
|
963 |
+
q.stride(1),
|
964 |
+
k.stride(0),
|
965 |
+
k.stride(2),
|
966 |
+
k.stride(1),
|
967 |
+
v.stride(0),
|
968 |
+
v.stride(2),
|
969 |
+
v.stride(1),
|
970 |
+
*bias_strides,
|
971 |
+
do.stride(0),
|
972 |
+
do.stride(2),
|
973 |
+
do.stride(1),
|
974 |
+
dq_accum.stride(0),
|
975 |
+
dq_accum.stride(2),
|
976 |
+
dq_accum.stride(1),
|
977 |
+
dk.stride(0),
|
978 |
+
dk.stride(2),
|
979 |
+
dk.stride(1),
|
980 |
+
dv.stride(0),
|
981 |
+
dv.stride(2),
|
982 |
+
dv.stride(1),
|
983 |
+
nheads,
|
984 |
+
seqlen_q,
|
985 |
+
seqlen_k,
|
986 |
+
seqlen_q_rounded,
|
987 |
+
d,
|
988 |
+
seqlen_q // 32,
|
989 |
+
seqlen_k // 32, # key for triton cache (limit number of compilations)
|
990 |
+
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
|
991 |
+
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
|
992 |
+
bias_type,
|
993 |
+
causal,
|
994 |
+
BLOCK_HEADDIM,
|
995 |
+
# SEQUENCE_PARALLEL=False,
|
996 |
+
# BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
|
997 |
+
# num_warps=num_warps,
|
998 |
+
# num_stages=1,
|
999 |
+
)
|
1000 |
+
dq.copy_(dq_accum)
|
1001 |
+
|
1002 |
+
|
1003 |
+
class _FlashAttnQKVPackedFunc(torch.autograd.Function):
|
1004 |
+
|
1005 |
+
@staticmethod
|
1006 |
+
def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
|
1007 |
+
"""Forward pass for packed FlashAttention.
|
1008 |
+
|
1009 |
+
Args:
|
1010 |
+
ctx: autograd context
|
1011 |
+
qkv: (batch, seqlen, 3, nheads, headdim)
|
1012 |
+
bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
|
1013 |
+
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
|
1014 |
+
ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
|
1015 |
+
causal (bool): whether to incorporate causal attention masking
|
1016 |
+
softmax_scale (float, optional): scale factor for softmax
|
1017 |
+
"""
|
1018 |
+
# Make sure that the last dimension is contiguous
|
1019 |
+
if qkv.stride(-1) != 1:
|
1020 |
+
qkv = qkv.contiguous()
|
1021 |
+
o, lse, ctx.softmax_scale = _flash_attn_forward(
|
1022 |
+
qkv[:, :, 0],
|
1023 |
+
qkv[:, :, 1],
|
1024 |
+
qkv[:, :, 2],
|
1025 |
+
bias=bias,
|
1026 |
+
causal=causal,
|
1027 |
+
softmax_scale=softmax_scale)
|
1028 |
+
ctx.save_for_backward(qkv, o, lse, bias)
|
1029 |
+
ctx.causal = causal
|
1030 |
+
return o
|
1031 |
+
|
1032 |
+
@staticmethod
|
1033 |
+
def backward(ctx, do):
|
1034 |
+
qkv, o, lse, bias = ctx.saved_tensors
|
1035 |
+
assert not ctx.needs_input_grad[
|
1036 |
+
1], 'FlashAttention does not support bias gradient yet'
|
1037 |
+
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
|
1038 |
+
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
|
1039 |
+
with torch.inference_mode():
|
1040 |
+
dqkv = torch.empty_like(qkv)
|
1041 |
+
_flash_attn_backward(do,
|
1042 |
+
qkv[:, :, 0],
|
1043 |
+
qkv[:, :, 1],
|
1044 |
+
qkv[:, :, 2],
|
1045 |
+
o,
|
1046 |
+
lse,
|
1047 |
+
dqkv[:, :, 0],
|
1048 |
+
dqkv[:, :, 1],
|
1049 |
+
dqkv[:, :, 2],
|
1050 |
+
bias=bias,
|
1051 |
+
causal=ctx.causal,
|
1052 |
+
softmax_scale=ctx.softmax_scale)
|
1053 |
+
return dqkv, None, None, None
|
1054 |
+
|
1055 |
+
|
1056 |
+
flash_attn_qkvpacked_func = _FlashAttnQKVPackedFunc.apply
|
1057 |
+
|
1058 |
+
|
1059 |
+
class _FlashAttnFunc(torch.autograd.Function):
|
1060 |
+
|
1061 |
+
@staticmethod
|
1062 |
+
def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
|
1063 |
+
"""Forward pass for FlashAttention.
|
1064 |
+
|
1065 |
+
Args:
|
1066 |
+
ctx: autograd context
|
1067 |
+
q: (batch_size, seqlen_q, nheads, headdim)
|
1068 |
+
k: (batch_size, seqlen_k, nheads, headdim)
|
1069 |
+
v: (batch_size, seqlen_k, nheads, headdim)
|
1070 |
+
bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
|
1071 |
+
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
|
1072 |
+
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
|
1073 |
+
causal (bool): whether to incorporate causal attention masking
|
1074 |
+
softmax_scale (float, optional): scale factor for softmax
|
1075 |
+
"""
|
1076 |
+
# Make sure that the last dimension is contiguous
|
1077 |
+
q, k, v = [
|
1078 |
+
x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]
|
1079 |
+
]
|
1080 |
+
o, lse, ctx.softmax_scale = _flash_attn_forward(
|
1081 |
+
q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)
|
1082 |
+
ctx.save_for_backward(q, k, v, o, lse, bias)
|
1083 |
+
ctx.causal = causal
|
1084 |
+
return o
|
1085 |
+
|
1086 |
+
@staticmethod
|
1087 |
+
def backward(ctx, do):
|
1088 |
+
q, k, v, o, lse, bias = ctx.saved_tensors
|
1089 |
+
assert not ctx.needs_input_grad[
|
1090 |
+
3], 'FlashAttention does not support bias gradient yet'
|
1091 |
+
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
|
1092 |
+
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
|
1093 |
+
with torch.inference_mode():
|
1094 |
+
dq = torch.empty_like(q)
|
1095 |
+
dk = torch.empty_like(k)
|
1096 |
+
dv = torch.empty_like(v)
|
1097 |
+
_flash_attn_backward(do,
|
1098 |
+
q,
|
1099 |
+
k,
|
1100 |
+
v,
|
1101 |
+
o,
|
1102 |
+
lse,
|
1103 |
+
dq,
|
1104 |
+
dk,
|
1105 |
+
dv,
|
1106 |
+
bias=bias,
|
1107 |
+
causal=ctx.causal,
|
1108 |
+
softmax_scale=ctx.softmax_scale)
|
1109 |
+
return dq, dk, dv, None, None, None
|
1110 |
+
|
1111 |
+
|
1112 |
+
flash_attn_func = _FlashAttnFunc.apply
|
modules.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"idx": 0,
|
4 |
+
"name": "0",
|
5 |
+
"path": "",
|
6 |
+
"type": "sentence_transformers.models.Transformer"
|
7 |
+
},
|
8 |
+
{
|
9 |
+
"idx": 1,
|
10 |
+
"name": "1",
|
11 |
+
"path": "1_Pooling",
|
12 |
+
"type": "sentence_transformers.models.Pooling"
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"idx": 2,
|
16 |
+
"name": "2",
|
17 |
+
"path": "2_Normalize",
|
18 |
+
"type": "sentence_transformers.models.Normalize"
|
19 |
+
}
|
20 |
+
]
|
mteb_results/AmazonCounterfactualClassification.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "e8379541af4e31359cca9fbcf4b00f2671dba205",
|
3 |
+
"mteb_dataset_name": "AmazonCounterfactualClassification",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"en": {
|
7 |
+
"accuracy": 0.697313432835821,
|
8 |
+
"accuracy_stderr": 0.04363113167902916,
|
9 |
+
"ap": 0.31618259511417734,
|
10 |
+
"ap_stderr": 0.0243939127481388,
|
11 |
+
"f1": 0.6330313825394228,
|
12 |
+
"f1_stderr": 0.03331211721747352,
|
13 |
+
"main_score": 0.697313432835821
|
14 |
+
},
|
15 |
+
"evaluation_time": 3.55
|
16 |
+
},
|
17 |
+
"validation": {
|
18 |
+
"en": {
|
19 |
+
"accuracy": 0.7074626865671642,
|
20 |
+
"accuracy_stderr": 0.03173177854547658,
|
21 |
+
"ap": 0.2916547890175021,
|
22 |
+
"ap_stderr": 0.028577509879931906,
|
23 |
+
"f1": 0.628207439570022,
|
24 |
+
"f1_stderr": 0.02728677964172927,
|
25 |
+
"main_score": 0.7074626865671642
|
26 |
+
},
|
27 |
+
"evaluation_time": 7.2
|
28 |
+
}
|
29 |
+
}
|
mteb_results/AmazonPolarityClassification.json
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "e2d317d38cd51312af73b3d32a06d1a08b442046",
|
3 |
+
"mteb_dataset_name": "AmazonPolarityClassification",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"accuracy": 0.8689837499999999,
|
7 |
+
"accuracy_stderr": 0.010742354621427285,
|
8 |
+
"ap": 0.8239500885672127,
|
9 |
+
"ap_stderr": 0.013236818266475252,
|
10 |
+
"evaluation_time": 1082.95,
|
11 |
+
"f1": 0.8687317947399658,
|
12 |
+
"f1_stderr": 0.011035411217540664,
|
13 |
+
"main_score": 0.8689837499999999
|
14 |
+
}
|
15 |
+
}
|
mteb_results/AmazonReviewsClassification.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "1399c76144fd37290681b995c656ef9b2e06e26d",
|
3 |
+
"mteb_dataset_name": "AmazonReviewsClassification",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"en": {
|
7 |
+
"accuracy": 0.44049999999999995,
|
8 |
+
"accuracy_stderr": 0.014423938435808711,
|
9 |
+
"f1": 0.4267624383248947,
|
10 |
+
"f1_stderr": 0.01351683620968048,
|
11 |
+
"main_score": 0.44049999999999995
|
12 |
+
},
|
13 |
+
"evaluation_time": 11.38
|
14 |
+
},
|
15 |
+
"validation": {
|
16 |
+
"en": {
|
17 |
+
"accuracy": 0.43798000000000004,
|
18 |
+
"accuracy_stderr": 0.012288352208494032,
|
19 |
+
"f1": 0.42483998553432956,
|
20 |
+
"f1_stderr": 0.015752944478543963,
|
21 |
+
"main_score": 0.43798000000000004
|
22 |
+
},
|
23 |
+
"evaluation_time": 13.69
|
24 |
+
}
|
25 |
+
}
|
mteb_results/ArguAna.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": null,
|
3 |
+
"mteb_dataset_name": "ArguAna",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 52.53,
|
7 |
+
"map_at_1": 0.26174,
|
8 |
+
"map_at_10": 0.40976,
|
9 |
+
"map_at_100": 0.42067,
|
10 |
+
"map_at_1000": 0.42075,
|
11 |
+
"map_at_3": 0.35917,
|
12 |
+
"map_at_5": 0.38656,
|
13 |
+
"mrr_at_1": 0.26814,
|
14 |
+
"mrr_at_10": 0.41252,
|
15 |
+
"mrr_at_100": 0.42337,
|
16 |
+
"mrr_at_1000": 0.42345,
|
17 |
+
"mrr_at_3": 0.36226,
|
18 |
+
"mrr_at_5": 0.38914,
|
19 |
+
"ndcg_at_1": 0.26174,
|
20 |
+
"ndcg_at_10": 0.49819,
|
21 |
+
"ndcg_at_100": 0.54404,
|
22 |
+
"ndcg_at_1000": 0.5459,
|
23 |
+
"ndcg_at_3": 0.39231,
|
24 |
+
"ndcg_at_5": 0.44189,
|
25 |
+
"precision_at_1": 0.26174,
|
26 |
+
"precision_at_10": 0.07838,
|
27 |
+
"precision_at_100": 0.00982,
|
28 |
+
"precision_at_1000": 0.001,
|
29 |
+
"precision_at_3": 0.16287,
|
30 |
+
"precision_at_5": 0.12191,
|
31 |
+
"recall_at_1": 0.26174,
|
32 |
+
"recall_at_10": 0.78378,
|
33 |
+
"recall_at_100": 0.98222,
|
34 |
+
"recall_at_1000": 0.99644,
|
35 |
+
"recall_at_3": 0.48862,
|
36 |
+
"recall_at_5": 0.60953
|
37 |
+
}
|
38 |
+
}
|
mteb_results/ArxivClusteringP2P.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "a122ad7f3f0291bf49cc6f4d32aa80929df69d5d",
|
3 |
+
"mteb_dataset_name": "ArxivClusteringP2P",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 4034.06,
|
7 |
+
"v_measure": 0.4231689035788179,
|
8 |
+
"v_measure_std": 0.1399577095144373
|
9 |
+
}
|
10 |
+
}
|
mteb_results/ArxivClusteringS2S.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "f910caf1a6075f7329cdf8c1a6135696f37dbd53",
|
3 |
+
"mteb_dataset_name": "ArxivClusteringS2S",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 432.48,
|
7 |
+
"v_measure": 0.31280245136660983,
|
8 |
+
"v_measure_std": 0.14616358182910433
|
9 |
+
}
|
10 |
+
}
|
mteb_results/AskUbuntuDupQuestions.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "2000358ca161889fa9c082cb41daa8dcfb161a54",
|
3 |
+
"mteb_dataset_name": "AskUbuntuDupQuestions",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 4.26,
|
7 |
+
"map": 0.5879109720839415,
|
8 |
+
"mrr": 0.7179615705931495
|
9 |
+
}
|
10 |
+
}
|
mteb_results/BIOSSES.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "d3fb88f8f02e40887cd149695127462bbcf29b4a",
|
3 |
+
"mteb_dataset_name": "BIOSSES",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"cos_sim": {
|
7 |
+
"pearson": 0.7644918756608116,
|
8 |
+
"spearman": 0.7086607256286257
|
9 |
+
},
|
10 |
+
"euclidean": {
|
11 |
+
"pearson": 0.7412154678100815,
|
12 |
+
"spearman": 0.7086607256286257
|
13 |
+
},
|
14 |
+
"evaluation_time": 1.08,
|
15 |
+
"manhattan": {
|
16 |
+
"pearson": 0.7400786269644171,
|
17 |
+
"spearman": 0.7068353828321327
|
18 |
+
}
|
19 |
+
}
|
20 |
+
}
|
mteb_results/Banking77Classification.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "0fd18e25b25c072e09e0d92ab615fda904d66300",
|
3 |
+
"mteb_dataset_name": "Banking77Classification",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"accuracy": 0.7540584415584415,
|
7 |
+
"accuracy_stderr": 0.007828985179390284,
|
8 |
+
"evaluation_time": 20.37,
|
9 |
+
"f1": 0.7429514617572676,
|
10 |
+
"f1_stderr": 0.00868929710762345,
|
11 |
+
"main_score": 0.7540584415584415
|
12 |
+
}
|
13 |
+
}
|
mteb_results/BiorxivClusteringP2P.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "65b79d1d13f80053f67aca9498d9402c2d9f1f40",
|
3 |
+
"mteb_dataset_name": "BiorxivClusteringP2P",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 547.26,
|
7 |
+
"v_measure": 0.3741860080664014,
|
8 |
+
"v_measure_std": 0.008407780040443218
|
9 |
+
}
|
10 |
+
}
|
mteb_results/BiorxivClusteringS2S.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "258694dd0231531bc1fd9de6ceb52a0853c6d908",
|
3 |
+
"mteb_dataset_name": "BiorxivClusteringS2S",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 53.57,
|
7 |
+
"v_measure": 0.29319217023090705,
|
8 |
+
"v_measure_std": 0.010219281239166302
|
9 |
+
}
|
10 |
+
}
|
mteb_results/CQADupstackEnglishRetrieval.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": null,
|
3 |
+
"mteb_dataset_name": "CQADupstackEnglishRetrieval",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 113.5,
|
7 |
+
"map_at_1": 0.22528,
|
8 |
+
"map_at_10": 0.30751,
|
9 |
+
"map_at_100": 0.31855,
|
10 |
+
"map_at_1000": 0.31972,
|
11 |
+
"map_at_3": 0.28465,
|
12 |
+
"map_at_5": 0.29738,
|
13 |
+
"mrr_at_1": 0.28662,
|
14 |
+
"mrr_at_10": 0.35912,
|
15 |
+
"mrr_at_100": 0.36726,
|
16 |
+
"mrr_at_1000": 0.36777,
|
17 |
+
"mrr_at_3": 0.34013,
|
18 |
+
"mrr_at_5": 0.35156,
|
19 |
+
"ndcg_at_1": 0.28662,
|
20 |
+
"ndcg_at_10": 0.35452,
|
21 |
+
"ndcg_at_100": 0.401,
|
22 |
+
"ndcg_at_1000": 0.42323,
|
23 |
+
"ndcg_at_3": 0.32112,
|
24 |
+
"ndcg_at_5": 0.33638,
|
25 |
+
"precision_at_1": 0.28662,
|
26 |
+
"precision_at_10": 0.06688,
|
27 |
+
"precision_at_100": 0.0113,
|
28 |
+
"precision_at_1000": 0.0016,
|
29 |
+
"precision_at_3": 0.15563,
|
30 |
+
"precision_at_5": 0.11019,
|
31 |
+
"recall_at_1": 0.22528,
|
32 |
+
"recall_at_10": 0.43748,
|
33 |
+
"recall_at_100": 0.64235,
|
34 |
+
"recall_at_1000": 0.78609,
|
35 |
+
"recall_at_3": 0.33937,
|
36 |
+
"recall_at_5": 0.38234
|
37 |
+
}
|
38 |
+
}
|
mteb_results/ClimateFEVER.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": null,
|
3 |
+
"mteb_dataset_name": "ClimateFEVER",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 8671.85,
|
7 |
+
"map_at_1": 0.09468,
|
8 |
+
"map_at_10": 0.16029,
|
9 |
+
"map_at_100": 0.17693,
|
10 |
+
"map_at_1000": 0.17886,
|
11 |
+
"map_at_3": 0.1315,
|
12 |
+
"map_at_5": 0.14568,
|
13 |
+
"mrr_at_1": 0.21173,
|
14 |
+
"mrr_at_10": 0.31028,
|
15 |
+
"mrr_at_100": 0.32061,
|
16 |
+
"mrr_at_1000": 0.32119,
|
17 |
+
"mrr_at_3": 0.27535,
|
18 |
+
"mrr_at_5": 0.29431,
|
19 |
+
"ndcg_at_1": 0.21173,
|
20 |
+
"ndcg_at_10": 0.23224,
|
21 |
+
"ndcg_at_100": 0.30225,
|
22 |
+
"ndcg_at_1000": 0.33961,
|
23 |
+
"ndcg_at_3": 0.18174,
|
24 |
+
"ndcg_at_5": 0.19897,
|
25 |
+
"precision_at_1": 0.21173,
|
26 |
+
"precision_at_10": 0.07472,
|
27 |
+
"precision_at_100": 0.01501,
|
28 |
+
"precision_at_1000": 0.00219,
|
29 |
+
"precision_at_3": 0.13312,
|
30 |
+
"precision_at_5": 0.10619,
|
31 |
+
"recall_at_1": 0.09468,
|
32 |
+
"recall_at_10": 0.28823,
|
33 |
+
"recall_at_100": 0.53265,
|
34 |
+
"recall_at_1000": 0.74536,
|
35 |
+
"recall_at_3": 0.16672,
|
36 |
+
"recall_at_5": 0.21302
|
37 |
+
}
|
38 |
+
}
|
mteb_results/DBPedia.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": null,
|
3 |
+
"mteb_dataset_name": "DBPedia",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 4445.99,
|
7 |
+
"map_at_1": 0.06343,
|
8 |
+
"map_at_10": 0.12717,
|
9 |
+
"map_at_100": 0.1648,
|
10 |
+
"map_at_1000": 0.17381,
|
11 |
+
"map_at_3": 0.09569,
|
12 |
+
"map_at_5": 0.11125,
|
13 |
+
"mrr_at_1": 0.4875,
|
14 |
+
"mrr_at_10": 0.58425,
|
15 |
+
"mrr_at_100": 0.59075,
|
16 |
+
"mrr_at_1000": 0.59095,
|
17 |
+
"mrr_at_3": 0.56292,
|
18 |
+
"mrr_at_5": 0.57679,
|
19 |
+
"ndcg_at_1": 0.37875,
|
20 |
+
"ndcg_at_10": 0.2777,
|
21 |
+
"ndcg_at_100": 0.30289,
|
22 |
+
"ndcg_at_1000": 0.36188,
|
23 |
+
"ndcg_at_3": 0.31386,
|
24 |
+
"ndcg_at_5": 0.29923,
|
25 |
+
"precision_at_1": 0.4875,
|
26 |
+
"precision_at_10": 0.22375,
|
27 |
+
"precision_at_100": 0.06342,
|
28 |
+
"precision_at_1000": 0.01449,
|
29 |
+
"precision_at_3": 0.355,
|
30 |
+
"precision_at_5": 0.3055,
|
31 |
+
"recall_at_1": 0.06343,
|
32 |
+
"recall_at_10": 0.16936,
|
33 |
+
"recall_at_100": 0.35956,
|
34 |
+
"recall_at_1000": 0.55787,
|
35 |
+
"recall_at_3": 0.10771,
|
36 |
+
"recall_at_5": 0.1367
|
37 |
+
}
|
38 |
+
}
|
mteb_results/EmotionClassification.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "4f58c6b202a23cf9a4da393831edf4f9183cad37",
|
3 |
+
"mteb_dataset_name": "EmotionClassification",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"accuracy": 0.4199,
|
7 |
+
"accuracy_stderr": 0.02234367919569201,
|
8 |
+
"evaluation_time": 3.37,
|
9 |
+
"f1": 0.3682340217456495,
|
10 |
+
"f1_stderr": 0.021776128234136445,
|
11 |
+
"main_score": 0.4199
|
12 |
+
},
|
13 |
+
"validation": {
|
14 |
+
"accuracy": 0.41864999999999997,
|
15 |
+
"accuracy_stderr": 0.022959801828413062,
|
16 |
+
"evaluation_time": 3.29,
|
17 |
+
"f1": 0.3748604511300154,
|
18 |
+
"f1_stderr": 0.02042335727335004,
|
19 |
+
"main_score": 0.41864999999999997
|
20 |
+
}
|
21 |
+
}
|
mteb_results/FEVER.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": null,
|
3 |
+
"mteb_dataset_name": "FEVER",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 16698.45,
|
7 |
+
"map_at_1": 0.40088,
|
8 |
+
"map_at_10": 0.52692,
|
9 |
+
"map_at_100": 0.53296,
|
10 |
+
"map_at_1000": 0.53325,
|
11 |
+
"map_at_3": 0.49905,
|
12 |
+
"map_at_5": 0.51617,
|
13 |
+
"mrr_at_1": 0.43009,
|
14 |
+
"mrr_at_10": 0.56203,
|
15 |
+
"mrr_at_100": 0.5675,
|
16 |
+
"mrr_at_1000": 0.56769,
|
17 |
+
"mrr_at_3": 0.534,
|
18 |
+
"mrr_at_5": 0.55163,
|
19 |
+
"ndcg_at_1": 0.43009,
|
20 |
+
"ndcg_at_10": 0.5939,
|
21 |
+
"ndcg_at_100": 0.6213,
|
22 |
+
"ndcg_at_1000": 0.62793,
|
23 |
+
"ndcg_at_3": 0.53878,
|
24 |
+
"ndcg_at_5": 0.56887,
|
25 |
+
"precision_at_1": 0.43009,
|
26 |
+
"precision_at_10": 0.08366,
|
27 |
+
"precision_at_100": 0.00983,
|
28 |
+
"precision_at_1000": 0.00105,
|
29 |
+
"precision_at_3": 0.22377,
|
30 |
+
"precision_at_5": 0.15035,
|
31 |
+
"recall_at_1": 0.40088,
|
32 |
+
"recall_at_10": 0.76687,
|
33 |
+
"recall_at_100": 0.8891,
|
34 |
+
"recall_at_1000": 0.93782,
|
35 |
+
"recall_at_3": 0.6181,
|
36 |
+
"recall_at_5": 0.69131
|
37 |
+
}
|
38 |
+
}
|
mteb_results/FiQA2018.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": null,
|
3 |
+
"mteb_dataset_name": "FiQA2018",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 138.39,
|
7 |
+
"map_at_1": 0.10817,
|
8 |
+
"map_at_10": 0.189,
|
9 |
+
"map_at_100": 0.20448,
|
10 |
+
"map_at_1000": 0.20661,
|
11 |
+
"map_at_3": 0.15979,
|
12 |
+
"map_at_5": 0.17415,
|
13 |
+
"mrr_at_1": 0.23148,
|
14 |
+
"mrr_at_10": 0.31208,
|
15 |
+
"mrr_at_100": 0.32167,
|
16 |
+
"mrr_at_1000": 0.32242,
|
17 |
+
"mrr_at_3": 0.28498,
|
18 |
+
"mrr_at_5": 0.29964,
|
19 |
+
"ndcg_at_1": 0.23148,
|
20 |
+
"ndcg_at_10": 0.25326,
|
21 |
+
"ndcg_at_100": 0.31927,
|
22 |
+
"ndcg_at_1000": 0.36081,
|
23 |
+
"ndcg_at_3": 0.21647,
|
24 |
+
"ndcg_at_5": 0.22763,
|
25 |
+
"precision_at_1": 0.23148,
|
26 |
+
"precision_at_10": 0.07546,
|
27 |
+
"precision_at_100": 0.01415,
|
28 |
+
"precision_at_1000": 0.00216,
|
29 |
+
"precision_at_3": 0.14969,
|
30 |
+
"precision_at_5": 0.11327,
|
31 |
+
"recall_at_1": 0.10817,
|
32 |
+
"recall_at_10": 0.32164,
|
33 |
+
"recall_at_100": 0.57655,
|
34 |
+
"recall_at_1000": 0.82797,
|
35 |
+
"recall_at_3": 0.19709,
|
36 |
+
"recall_at_5": 0.24333
|
37 |
+
}
|
38 |
+
}
|
mteb_results/HotpotQA.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": null,
|
3 |
+
"mteb_dataset_name": "HotpotQA",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 5192.2,
|
7 |
+
"map_at_1": 0.25381,
|
8 |
+
"map_at_10": 0.3314,
|
9 |
+
"map_at_100": 0.33948,
|
10 |
+
"map_at_1000": 0.34028,
|
11 |
+
"map_at_3": 0.3102,
|
12 |
+
"map_at_5": 0.3223,
|
13 |
+
"mrr_at_1": 0.50763,
|
14 |
+
"mrr_at_10": 0.57899,
|
15 |
+
"mrr_at_100": 0.58426,
|
16 |
+
"mrr_at_1000": 0.58457,
|
17 |
+
"mrr_at_3": 0.56093,
|
18 |
+
"mrr_at_5": 0.57116,
|
19 |
+
"ndcg_at_1": 0.50763,
|
20 |
+
"ndcg_at_10": 0.41656,
|
21 |
+
"ndcg_at_100": 0.45079,
|
22 |
+
"ndcg_at_1000": 0.46917,
|
23 |
+
"ndcg_at_3": 0.37834,
|
24 |
+
"ndcg_at_5": 0.39732,
|
25 |
+
"precision_at_1": 0.50763,
|
26 |
+
"precision_at_10": 0.08648,
|
27 |
+
"precision_at_100": 0.01135,
|
28 |
+
"precision_at_1000": 0.00138,
|
29 |
+
"precision_at_3": 0.23106,
|
30 |
+
"precision_at_5": 0.15363,
|
31 |
+
"recall_at_1": 0.25381,
|
32 |
+
"recall_at_10": 0.43241,
|
33 |
+
"recall_at_100": 0.56745,
|
34 |
+
"recall_at_1000": 0.69048,
|
35 |
+
"recall_at_3": 0.34659,
|
36 |
+
"recall_at_5": 0.38406
|
37 |
+
}
|
38 |
+
}
|
mteb_results/ImdbClassification.json
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "3d86128a09e091d6018b6d26cad27f2739fc2db7",
|
3 |
+
"mteb_dataset_name": "ImdbClassification",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"accuracy": 0.79544,
|
7 |
+
"accuracy_stderr": 0.022193916283522398,
|
8 |
+
"ap": 0.7382920133396664,
|
9 |
+
"ap_stderr": 0.029776228173533717,
|
10 |
+
"evaluation_time": 205.12,
|
11 |
+
"f1": 0.7951048124883265,
|
12 |
+
"f1_stderr": 0.02219958939576688,
|
13 |
+
"main_score": 0.79544
|
14 |
+
}
|
15 |
+
}
|
mteb_results/MSMARCO.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": null,
|
3 |
+
"dev": {
|
4 |
+
"evaluation_time": 19626.59,
|
5 |
+
"map_at_1": 0.11174,
|
6 |
+
"map_at_10": 0.19452,
|
7 |
+
"map_at_100": 0.20612,
|
8 |
+
"map_at_1000": 0.20703,
|
9 |
+
"map_at_3": 0.16444,
|
10 |
+
"map_at_5": 0.18083,
|
11 |
+
"mrr_at_1": 0.11447,
|
12 |
+
"mrr_at_10": 0.19808,
|
13 |
+
"mrr_at_100": 0.20958,
|
14 |
+
"mrr_at_1000": 0.21042,
|
15 |
+
"mrr_at_3": 0.16791,
|
16 |
+
"mrr_at_5": 0.18459,
|
17 |
+
"ndcg_at_1": 0.11447,
|
18 |
+
"ndcg_at_10": 0.24556,
|
19 |
+
"ndcg_at_100": 0.30638,
|
20 |
+
"ndcg_at_1000": 0.3314,
|
21 |
+
"ndcg_at_3": 0.18325,
|
22 |
+
"ndcg_at_5": 0.21278,
|
23 |
+
"precision_at_1": 0.11447,
|
24 |
+
"precision_at_10": 0.04215,
|
25 |
+
"precision_at_100": 0.00732,
|
26 |
+
"precision_at_1000": 0.00095,
|
27 |
+
"precision_at_3": 0.08052,
|
28 |
+
"precision_at_5": 0.06318,
|
29 |
+
"recall_at_1": 0.11174,
|
30 |
+
"recall_at_10": 0.40543,
|
31 |
+
"recall_at_100": 0.69699,
|
32 |
+
"recall_at_1000": 0.89403,
|
33 |
+
"recall_at_3": 0.23442,
|
34 |
+
"recall_at_5": 0.30536
|
35 |
+
},
|
36 |
+
"mteb_dataset_name": "MSMARCO",
|
37 |
+
"mteb_version": "1.1.0"
|
38 |
+
}
|
mteb_results/MTOPDomainClassification.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "d80d48c1eb48d3562165c59d59d0034df9fff0bf",
|
3 |
+
"mteb_dataset_name": "MTOPDomainClassification",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"en": {
|
7 |
+
"accuracy": 0.8966712266301871,
|
8 |
+
"accuracy_stderr": 0.009523011920085962,
|
9 |
+
"f1": 0.8957660424361247,
|
10 |
+
"f1_stderr": 0.009247170021662966,
|
11 |
+
"main_score": 0.8966712266301871
|
12 |
+
},
|
13 |
+
"evaluation_time": 7.75
|
14 |
+
},
|
15 |
+
"validation": {
|
16 |
+
"en": {
|
17 |
+
"accuracy": 0.9017002237136464,
|
18 |
+
"accuracy_stderr": 0.009890167527403295,
|
19 |
+
"f1": 0.9039792204701363,
|
20 |
+
"f1_stderr": 0.009182351003334687,
|
21 |
+
"main_score": 0.9017002237136464
|
22 |
+
},
|
23 |
+
"evaluation_time": 4.88
|
24 |
+
}
|
25 |
+
}
|
mteb_results/MTOPIntentClassification.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "ae001d0e6b1228650b7bd1c2c65fb50ad11a8aba",
|
3 |
+
"mteb_dataset_name": "MTOPIntentClassification",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"en": {
|
7 |
+
"accuracy": 0.6028499772001825,
|
8 |
+
"accuracy_stderr": 0.018495543127366038,
|
9 |
+
"f1": 0.40306374001528233,
|
10 |
+
"f1_stderr": 0.011859407815520086,
|
11 |
+
"main_score": 0.6028499772001825
|
12 |
+
},
|
13 |
+
"evaluation_time": 30.96
|
14 |
+
},
|
15 |
+
"validation": {
|
16 |
+
"en": {
|
17 |
+
"accuracy": 0.6150335570469799,
|
18 |
+
"accuracy_stderr": 0.01903139236025276,
|
19 |
+
"f1": 0.4147129810603558,
|
20 |
+
"f1_stderr": 0.015560901035463594,
|
21 |
+
"main_score": 0.6150335570469799
|
22 |
+
},
|
23 |
+
"evaluation_time": 28.07
|
24 |
+
}
|
25 |
+
}
|
mteb_results/MassiveIntentClassification.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "31efe3c427b0bae9c22cbb560b8f15491cc6bed7",
|
3 |
+
"mteb_dataset_name": "MassiveIntentClassification",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"en": {
|
7 |
+
"accuracy": 0.6333557498318763,
|
8 |
+
"accuracy_stderr": 0.014612806300514952,
|
9 |
+
"f1": 0.6024039910680179,
|
10 |
+
"f1_stderr": 0.012256367770368185,
|
11 |
+
"main_score": 0.6333557498318763
|
12 |
+
},
|
13 |
+
"evaluation_time": 22.52
|
14 |
+
},
|
15 |
+
"validation": {
|
16 |
+
"en": {
|
17 |
+
"accuracy": 0.6426955238563699,
|
18 |
+
"accuracy_stderr": 0.01633350887848132,
|
19 |
+
"f1": 0.5828069832892886,
|
20 |
+
"f1_stderr": 0.013604921852646317,
|
21 |
+
"main_score": 0.6426955238563699
|
22 |
+
},
|
23 |
+
"evaluation_time": 17.6
|
24 |
+
}
|
25 |
+
}
|
mteb_results/MassiveScenarioClassification.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "7d571f92784cd94a019292a1f45445077d0ef634",
|
3 |
+
"mteb_dataset_name": "MassiveScenarioClassification",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"en": {
|
7 |
+
"accuracy": 0.7237390719569603,
|
8 |
+
"accuracy_stderr": 0.006043481355389665,
|
9 |
+
"f1": 0.7233097333477316,
|
10 |
+
"f1_stderr": 0.0075559844507943974,
|
11 |
+
"main_score": 0.7237390719569603
|
12 |
+
},
|
13 |
+
"evaluation_time": 6.86
|
14 |
+
},
|
15 |
+
"validation": {
|
16 |
+
"en": {
|
17 |
+
"accuracy": 0.7321200196753566,
|
18 |
+
"accuracy_stderr": 0.010745609148770754,
|
19 |
+
"f1": 0.7288011677053199,
|
20 |
+
"f1_stderr": 0.010826173990376636,
|
21 |
+
"main_score": 0.7321200196753566
|
22 |
+
},
|
23 |
+
"evaluation_time": 5.61
|
24 |
+
}
|
25 |
+
}
|
mteb_results/MedrxivClusteringP2P.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "e7a26af6f3ae46b30dde8737f02c07b1505bcc73",
|
3 |
+
"mteb_dataset_name": "MedrxivClusteringP2P",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 226.37,
|
7 |
+
"v_measure": 0.34681589390605516,
|
8 |
+
"v_measure_std": 0.01515645822647098
|
9 |
+
}
|
10 |
+
}
|
mteb_results/MedrxivClusteringS2S.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "35191c8c0dca72d8ff3efcd72aa802307d469663",
|
3 |
+
"mteb_dataset_name": "MedrxivClusteringS2S",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 31.13,
|
7 |
+
"v_measure": 0.30340061711905236,
|
8 |
+
"v_measure_std": 0.012579424998938571
|
9 |
+
}
|
10 |
+
}
|
mteb_results/MindSmallReranking.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "3bdac13927fdc888b903db93b2ffdbd90b295a69",
|
3 |
+
"mteb_dataset_name": "MindSmallReranking",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 1849.79,
|
7 |
+
"map": 0.32018143262958026,
|
8 |
+
"mrr": 0.33205552400553673
|
9 |
+
}
|
10 |
+
}
|
mteb_results/NFCorpus.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": null,
|
3 |
+
"mteb_dataset_name": "NFCorpus",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 17.47,
|
7 |
+
"map_at_1": 0.03391,
|
8 |
+
"map_at_10": 0.07722,
|
9 |
+
"map_at_100": 0.10286,
|
10 |
+
"map_at_1000": 0.11668,
|
11 |
+
"map_at_3": 0.05552,
|
12 |
+
"map_at_5": 0.06468,
|
13 |
+
"mrr_at_1": 0.34365,
|
14 |
+
"mrr_at_10": 0.42555,
|
15 |
+
"mrr_at_100": 0.43295,
|
16 |
+
"mrr_at_1000": 0.43357,
|
17 |
+
"mrr_at_3": 0.40299,
|
18 |
+
"mrr_at_5": 0.41182,
|
19 |
+
"ndcg_at_1": 0.31424,
|
20 |
+
"ndcg_at_10": 0.24758,
|
21 |
+
"ndcg_at_100": 0.23678,
|
22 |
+
"ndcg_at_1000": 0.33377,
|
23 |
+
"ndcg_at_3": 0.28302,
|
24 |
+
"ndcg_at_5": 0.26342,
|
25 |
+
"precision_at_1": 0.33437,
|
26 |
+
"precision_at_10": 0.19257,
|
27 |
+
"precision_at_100": 0.06663,
|
28 |
+
"precision_at_1000": 0.0199,
|
29 |
+
"precision_at_3": 0.27761,
|
30 |
+
"precision_at_5": 0.23715,
|
31 |
+
"recall_at_1": 0.03391,
|
32 |
+
"recall_at_10": 0.11068,
|
33 |
+
"recall_at_100": 0.25878,
|
34 |
+
"recall_at_1000": 0.6019,
|
35 |
+
"recall_at_3": 0.06169,
|
36 |
+
"recall_at_5": 0.07767
|
37 |
+
}
|
38 |
+
}
|
mteb_results/NQ.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": null,
|
3 |
+
"mteb_dataset_name": "NQ",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 7686.43,
|
7 |
+
"map_at_1": 0.15168,
|
8 |
+
"map_at_10": 0.26177,
|
9 |
+
"map_at_100": 0.27564,
|
10 |
+
"map_at_1000": 0.27629,
|
11 |
+
"map_at_3": 0.2203,
|
12 |
+
"map_at_5": 0.24276,
|
13 |
+
"mrr_at_1": 0.17439,
|
14 |
+
"mrr_at_10": 0.28205,
|
15 |
+
"mrr_at_100": 0.29357,
|
16 |
+
"mrr_at_1000": 0.29408,
|
17 |
+
"mrr_at_3": 0.24377,
|
18 |
+
"mrr_at_5": 0.2654,
|
19 |
+
"ndcg_at_1": 0.1741,
|
20 |
+
"ndcg_at_10": 0.32936,
|
21 |
+
"ndcg_at_100": 0.39197,
|
22 |
+
"ndcg_at_1000": 0.40892,
|
23 |
+
"ndcg_at_3": 0.24721,
|
24 |
+
"ndcg_at_5": 0.28615,
|
25 |
+
"precision_at_1": 0.1741,
|
26 |
+
"precision_at_10": 0.06199,
|
27 |
+
"precision_at_100": 0.00969,
|
28 |
+
"precision_at_1000": 0.00113,
|
29 |
+
"precision_at_3": 0.1179,
|
30 |
+
"precision_at_5": 0.09264,
|
31 |
+
"recall_at_1": 0.15168,
|
32 |
+
"recall_at_10": 0.51914,
|
33 |
+
"recall_at_100": 0.79804,
|
34 |
+
"recall_at_1000": 0.9276,
|
35 |
+
"recall_at_3": 0.30212,
|
36 |
+
"recall_at_5": 0.39204
|
37 |
+
}
|
38 |
+
}
|
mteb_results/QuoraRetrieval.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": null,
|
3 |
+
"mteb_dataset_name": "QuoraRetrieval",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 258.94,
|
7 |
+
"map_at_1": 0.67306,
|
8 |
+
"map_at_10": 0.80634,
|
9 |
+
"map_at_100": 0.81349,
|
10 |
+
"map_at_1000": 0.81373,
|
11 |
+
"map_at_3": 0.77691,
|
12 |
+
"map_at_5": 0.79512,
|
13 |
+
"mrr_at_1": 0.7756,
|
14 |
+
"mrr_at_10": 0.84177,
|
15 |
+
"mrr_at_100": 0.8435,
|
16 |
+
"mrr_at_1000": 0.84353,
|
17 |
+
"mrr_at_3": 0.83003,
|
18 |
+
"mrr_at_5": 0.83799,
|
19 |
+
"ndcg_at_1": 0.7758,
|
20 |
+
"ndcg_at_10": 0.84782,
|
21 |
+
"ndcg_at_100": 0.86443,
|
22 |
+
"ndcg_at_1000": 0.86654,
|
23 |
+
"ndcg_at_3": 0.8167,
|
24 |
+
"ndcg_at_5": 0.83356,
|
25 |
+
"precision_at_1": 0.7758,
|
26 |
+
"precision_at_10": 0.12875,
|
27 |
+
"precision_at_100": 0.01503,
|
28 |
+
"precision_at_1000": 0.00156,
|
29 |
+
"precision_at_3": 0.3563,
|
30 |
+
"precision_at_5": 0.23484,
|
31 |
+
"recall_at_1": 0.67306,
|
32 |
+
"recall_at_10": 0.9264,
|
33 |
+
"recall_at_100": 0.98681,
|
34 |
+
"recall_at_1000": 0.9979,
|
35 |
+
"recall_at_3": 0.83682,
|
36 |
+
"recall_at_5": 0.88424
|
37 |
+
}
|
38 |
+
}
|
mteb_results/RedditClustering.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "24640382cdbf8abc73003fb0fa6d111a705499eb",
|
3 |
+
"mteb_dataset_name": "RedditClustering",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 237.53,
|
7 |
+
"v_measure": 0.5076319866126382,
|
8 |
+
"v_measure_std": 0.04676162821389071
|
9 |
+
}
|
10 |
+
}
|
mteb_results/RedditClusteringP2P.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "282350215ef01743dc01b456c7f5241fa8937f16",
|
3 |
+
"mteb_dataset_name": "RedditClusteringP2P",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 1202.29,
|
7 |
+
"v_measure": 0.55024711941649,
|
8 |
+
"v_measure_std": 0.12775990781233748
|
9 |
+
}
|
10 |
+
}
|
mteb_results/SCIDOCS.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": null,
|
3 |
+
"mteb_dataset_name": "SCIDOCS",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"evaluation_time": 82.79,
|
7 |
+
"map_at_1": 0.03938,
|
8 |
+
"map_at_10": 0.08817,
|
9 |
+
"map_at_100": 0.10547,
|
10 |
+
"map_at_1000": 0.10852,
|
11 |
+
"map_at_3": 0.06352,
|
12 |
+
"map_at_5": 0.07453,
|
13 |
+
"mrr_at_1": 0.194,
|
14 |
+
"mrr_at_10": 0.27371,
|
15 |
+
"mrr_at_100": 0.28672,
|
16 |
+
"mrr_at_1000": 0.28747,
|
17 |
+
"mrr_at_3": 0.24583,
|
18 |
+
"mrr_at_5": 0.26143,
|
19 |
+
"ndcg_at_1": 0.194,
|
20 |
+
"ndcg_at_10": 0.15264,
|
21 |
+
"ndcg_at_100": 0.2263,
|
22 |
+
"ndcg_at_1000": 0.28559,
|
23 |
+
"ndcg_at_3": 0.14425,
|
24 |
+
"ndcg_at_5": 0.1252,
|
25 |
+
"precision_at_1": 0.194,
|
26 |
+
"precision_at_10": 0.0781,
|
27 |
+
"precision_at_100": 0.01854,
|
28 |
+
"precision_at_1000": 0.00329,
|
29 |
+
"precision_at_3": 0.131,
|
30 |
+
"precision_at_5": 0.1068,
|
31 |
+
"recall_at_1": 0.03938,
|
32 |
+
"recall_at_10": 0.15903,
|
33 |
+
"recall_at_100": 0.37645,
|
34 |
+
"recall_at_1000": 0.6686,
|
35 |
+
"recall_at_3": 0.07993,
|
36 |
+
"recall_at_5": 0.10885
|
37 |
+
}
|
38 |
+
}
|
mteb_results/SICK-R.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "a6ea5a8cab320b040a23452cc28066d9beae2cee",
|
3 |
+
"mteb_dataset_name": "SICK-R",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"cos_sim": {
|
7 |
+
"pearson": 0.8012689060151424,
|
8 |
+
"spearman": 0.7046515535094772
|
9 |
+
},
|
10 |
+
"euclidean": {
|
11 |
+
"pearson": 0.7717160003557223,
|
12 |
+
"spearman": 0.704651757047438
|
13 |
+
},
|
14 |
+
"evaluation_time": 7.91,
|
15 |
+
"manhattan": {
|
16 |
+
"pearson": 0.7718129609281936,
|
17 |
+
"spearman": 0.7046610403752913
|
18 |
+
}
|
19 |
+
}
|
20 |
+
}
|
mteb_results/STS12.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "a0d554a64d88156834ff5ae9920b964011b16384",
|
3 |
+
"mteb_dataset_name": "STS12",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"cos_sim": {
|
7 |
+
"pearson": 0.70451157033355,
|
8 |
+
"spearman": 0.6399899601697853
|
9 |
+
},
|
10 |
+
"euclidean": {
|
11 |
+
"pearson": 0.6746985359967678,
|
12 |
+
"spearman": 0.6400001637764805
|
13 |
+
},
|
14 |
+
"evaluation_time": 2.34,
|
15 |
+
"manhattan": {
|
16 |
+
"pearson": 0.6756534741780037,
|
17 |
+
"spearman": 0.6406533893575366
|
18 |
+
}
|
19 |
+
}
|
20 |
+
}
|
mteb_results/STS13.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "7e90230a92c190f1bf69ae9002b8cea547a64cca",
|
3 |
+
"mteb_dataset_name": "STS13",
|
4 |
+
"mteb_version": "1.1.0",
|
5 |
+
"test": {
|
6 |
+
"cos_sim": {
|
7 |
+
"pearson": 0.7765086614464292,
|
8 |
+
"spearman": 0.7820169706921849
|
9 |
+
},
|
10 |
+
"euclidean": {
|
11 |
+
"pearson": 0.7777758172155284,
|
12 |
+
"spearman": 0.7820169706921849
|
13 |
+
},
|
14 |
+
"evaluation_time": 1.03,
|
15 |
+
"manhattan": {
|
16 |
+
"pearson": 0.7775077884860052,
|
17 |
+
"spearman": 0.7816875216484164
|
18 |
+
}
|
19 |
+
}
|
20 |
+
}
|