Raphael Sourty committed
Commit ac13da9 · unverified · 1 Parent(s): a2b3335

initialize model
README.md CHANGED
---
language:
- en
license: mit
---

This model was trained with [Neural-Cherche](https://github.com/raphaelsty/neural-cherche). You can find details on how to fine-tune it in the [Neural-Cherche](https://github.com/raphaelsty/neural-cherche) repository.

This model is `sentence-transformers/all-mpnet-base-v2` fine-tuned as a ColBERT model.
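A ColBERT model encodes a query and a document into one embedding per token and scores the pair with late interaction (MaxSim): each query token is matched to its most similar document token, and the per-token maxima are summed. Below is a minimal sketch of that scoring rule, assuming two already-encoded, L2-normalized embedding matrices; it is an illustration of the technique, not Neural-Cherche's internal code.

```python
import torch

def maxsim_score(query_embeddings: torch.Tensor, document_embeddings: torch.Tensor) -> torch.Tensor:
    """Late-interaction (MaxSim) score between one query and one document.

    query_embeddings: (num_query_tokens, dim); document_embeddings: (num_doc_tokens, dim).
    Both are assumed L2-normalized, so dot products act as cosine similarities.
    """
    # Pairwise token similarities: (num_query_tokens, num_doc_tokens).
    similarities = query_embeddings @ document_embeddings.T
    # For each query token, keep its best-matching document token, then sum.
    return similarities.max(dim=1).values.sum()
```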
```sh
pip install neural-cherche
```

## Retriever

```python
from neural_cherche import models, retrieve
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32

documents = [
    {"id": 0, "document": "Food"},
    {"id": 1, "document": "Sports"},
    {"id": 2, "document": "Cinema"},
]

queries = ["Food", "Sports", "Cinema"]

model = models.ColBERT(
    model_name_or_path="raphaelsty",
    device=device,
)

retriever = retrieve.ColBERT(
    key="id",
    on=["document"],
    model=model,
)

# Encode every field listed in `on` into token-level embeddings.
documents_embeddings = retriever.encode_documents(
    documents=documents,
    batch_size=batch_size,
)

# Index the document embeddings.
retriever = retriever.add(
    documents_embeddings=documents_embeddings,
)

queries_embeddings = retriever.encode_queries(
    queries=queries,
    batch_size=batch_size,
)

# Retrieve the top-k documents for each query.
scores = retriever(
    queries_embeddings=queries_embeddings,
    batch_size=batch_size,
    k=3,
)

scores
```
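`scores` holds one ranked list per query; in Neural-Cherche each entry should be a dict carrying the document key (`id` here) along with a similarity score, so the top match for `"Food"` would be the document with `id` 0. Exact field names may vary between library versions.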
## Ranker

```python
from neural_cherche import models, rank, retrieve
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32

documents = [
    {"id": "doc1", "title": "Paris", "text": "Paris is the capital of France."},
    {"id": "doc2", "title": "Montreal", "text": "Montreal is the largest city in Quebec."},
    {"id": "doc3", "title": "Bordeaux", "text": "Bordeaux is in Southwestern France."},
]

queries = [
    "What is the capital of France?",
    "What is the largest city in Quebec?",
    "Where is Bordeaux?",
]

# Fast first-stage retriever: TF-IDF over the title and text fields.
retriever = retrieve.TfIdf(
    key="id",
    on=["title", "text"],
)

# This example loads the base checkpoint; pass this repository's id instead
# to rerank with the fine-tuned weights.
model = models.ColBERT(
    model_name_or_path="sentence-transformers/all-mpnet-base-v2",
    device=device,
)

ranker = rank.ColBERT(
    key="id",
    on=["title", "text"],
    model=model,
)

retriever_documents_embeddings = retriever.encode_documents(
    documents=documents,
)

retriever.add(
    documents_embeddings=retriever_documents_embeddings,
)

ranker_documents_embeddings = ranker.encode_documents(
    documents=documents,
    batch_size=batch_size,
)

retriever_queries_embeddings = retriever.encode_queries(
    queries=queries,
)

ranker_queries_embeddings = ranker.encode_queries(
    queries=queries,
    batch_size=batch_size,
)

# First stage: fetch up to 1000 TF-IDF candidates per query.
candidates = retriever(
    queries_embeddings=retriever_queries_embeddings,
    k=1000,
)

# Second stage: re-rank the candidates with ColBERT late interaction.
scores = ranker(
    documents=candidates,
    queries_embeddings=ranker_queries_embeddings,
    documents_embeddings=ranker_documents_embeddings,
    k=100,
    batch_size=batch_size,
)

scores
```
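The two stages split the work in the usual retrieve-then-rerank way: TF-IDF cheaply narrows the collection to at most 1,000 candidates per query, and the ColBERT ranker spends its more expensive late-interaction scoring only on those candidates, returning the best 100.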
config.json ADDED
{
  "_name_or_path": "sentence-transformers/all-mpnet-base-v2",
  "architectures": [
    "MPNetForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "mpnet",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 1,
  "relative_attention_num_buckets": 32,
  "torch_dtype": "float32",
  "transformers_version": "4.35.2",
  "vocab_size": 30527
}
linear.pt ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:51c6fe3495544322ca5339d03f91afe54f6234d441f5b69b6f062f483fc9ce99
size 394391
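A hedged reading of this file: it is most plausibly the dense projection head that ColBERT places on top of the encoder, since the byte count fits a 768 → 128 linear layer almost exactly. The sketch below illustrates such a head; the 128-dimensional width and the normalization step are assumptions, not read from the checkpoint.

```python
import torch

# Assumption: linear.pt holds ColBERT's projection head, a 768 -> 128 linear layer
# (128 is the conventional ColBERT embedding size; the true width here is not stated).
# Such a layer has 768 * 128 + 128 = 98,432 float32 parameters = 393,728 bytes,
# which matches the 394,391-byte file once serialization overhead is added.
projection = torch.nn.Linear(768, 128)

hidden_states = torch.randn(1, 256, 768)      # (batch, tokens, hidden_size) from MPNet
token_embeddings = projection(hidden_states)  # (batch, tokens, 128)
# L2-normalize so dot products in MaxSim behave as cosine similarities.
token_embeddings = torch.nn.functional.normalize(token_embeddings, p=2, dim=-1)
```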
metadata.json ADDED
{"max_length_query": 32, "max_length_document": 256, "query_prefix": "[Q] ", "document_prefix": "[D] "}
model.safetensors ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:987c71696a5e133e2dad24165c4839b9cf7cf1744e30cfe4eb3a2fcd077cf4eb
size 438097372
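At 438,097,372 bytes, `model.safetensors` is what you would expect for a full float32 dump of the encoder: 438,097,372 / 4 ≈ 109.5 million parameters, in line with the roughly 110M-parameter `all-mpnet-base-v2` backbone declared in `config.json`.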
special_tokens_map.json ADDED
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "cls_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "mask_token": {
    "content": "<mask>",
    "lstrip": true,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "<pad>",
  "sep_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "[UNK]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
{
  "added_tokens_decoder": {
    "0": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "104": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "30526": {
      "content": "<mask>",
      "lstrip": true,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": true,
  "cls_token": "<s>",
  "device": "cuda",
  "do_lower_case": true,
  "eos_token": "</s>",
  "mask_token": "<mask>",
  "max_length": 128,
  "model_max_length": 512,
  "pad_to_multiple_of": null,
  "pad_token": "<pad>",
  "pad_token_type_id": 0,
  "padding_side": "right",
  "sep_token": "</s>",
  "stride": 0,
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "MPNetTokenizer",
  "truncation_side": "right",
  "truncation_strategy": "longest_first",
  "unk_token": "[UNK]"
}
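Two details worth noting: MPNet mixes RoBERTa-style special tokens (`<s>`, `</s>`, `<pad>`, `<mask>`) with a BERT-style WordPiece vocabulary whose unknown token is `[UNK]`, which is why both conventions appear in this file; and the `"device": "cuda"` key is not a standard tokenizer option, so it appears to have been carried over from the training environment's keyword arguments.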
vocab.txt ADDED
The diff for this file is too large to render. See raw diff