BAAI
/

Shitao commited on
Commit
f03b549
·
verified ·
1 Parent(s): 56eba4c

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +123 -0
README.md ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ pipeline_tag: sentence-similarity
3
+ tags:
4
+ - sentence-transformers
5
+ - feature-extraction
6
+ - sentence-similarity
7
+ license: mit
8
+ ---
9
+
10
+ For more details please refer to our github repo: https://github.com/FlagOpen/FlagEmbedding
11
+
12
+ # LLARA ([paper](https://arxiv.org/pdf/2312.15503))
13
+
14
+ In this project, we introduce LLaRA:
15
+ - EBAE: Embedding-Based Auto-Encoding.
16
+ - EBAR: Embedding-Based Auto-Regression.
17
+
18
+
19
+ ## Usage
20
+
21
+ ```
22
+ import torch
23
+ from transformers import AutoModel, AutoTokenizer, LlamaModel
24
+
25
+ def get_query_inputs(queries, tokenizer, max_length=512):
26
+ prefix = '"'
27
+ suffix = '", predict the following passage within eight words: <s9><s10><s11><s12><s13><s14><s15><s16>'
28
+ prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
29
+ suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]
30
+ queries_inputs = []
31
+ for query in queries:
32
+ inputs = tokenizer(query,
33
+ return_tensors=None,
34
+ max_length=max_length,
35
+ truncation=True,
36
+ add_special_tokens=False)
37
+ inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
38
+ inputs['attention_mask'] = [1] * len(inputs['input_ids'])
39
+ queries_inputs.append(inputs)
40
+ return tokenizer.pad(
41
+ queries_inputs,
42
+ padding=True,
43
+ max_length=max_length,
44
+ pad_to_multiple_of=8,
45
+ return_tensors='pt',
46
+ )
47
+
48
+ def get_passage_inputs(passages, tokenizer, max_length=512):
49
+ prefix = '"'
50
+ suffix = '", summarize the above passage within eight words: <s1><s2><s3><s4><s5><s6><s7><s8>'
51
+ prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
52
+ suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]
53
+ passages_inputs = []
54
+ for passage in passages:
55
+ inputs = tokenizer(passage,
56
+ return_tensors=None,
57
+ max_length=max_length,
58
+ truncation=True,
59
+ add_special_tokens=False)
60
+ inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
61
+ inputs['attention_mask'] = [1] * len(inputs['input_ids'])
62
+ passages_inputs.append(inputs)
63
+ return tokenizer.pad(
64
+ passages_inputs,
65
+ padding=True,
66
+ max_length=max_length,
67
+ pad_to_multiple_of=8,
68
+ return_tensors='pt',
69
+ )
70
+
71
+ # Load the tokenizer and model
72
+ tokenizer = AutoTokenizer.from_pretrained('BAAI/LLARA-document')
73
+ model = AutoModel.from_pretrained('BAAI/LLARA-document')
74
+
75
+ # Define query and passage inputs
76
+ query = "What is llama?"
77
+ title = "Llama"
78
+ passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."
79
+ query_input = get_query_inputs([query], tokenizer)
80
+ passage_input = get_passage_inputs([passage], tokenizer)
81
+
82
+
83
+ with torch.no_grad():
84
+ # compute query embedding
85
+ query_outputs = model(**query_input, return_dict=True, output_hidden_states=True)
86
+ query_embedding = query_outputs.hidden_states[-1][:, -8:, :]
87
+ query_embedding = torch.mean(query_embedding, dim=1)
88
+ query_embedding = torch.nn.functional.normalize(query_embedding, dim=-1)
89
+
90
+ # compute passage embedding
91
+ passage_outputs = model(**passage_input, return_dict=True, output_hidden_states=True)
92
+ passage_embeddings = passage_outputs.hidden_states[-1][:, -8:, :]
93
+ passage_embeddings = torch.mean(passage_embeddings, dim=1)
94
+ passage_embeddings = torch.nn.functional.normalize(passage_embeddings, dim=-1)
95
+
96
+ # compute similarity score
97
+ score = query_embedding @ passage_embeddings.T
98
+ print(score)
99
+
100
+ ```
101
+
102
+
103
+ ## Acknowledgement
104
+
105
+ Thanks to the authors of open-sourced datasets, including MSMARCO, BEIR, etc.
106
+ Thanks to the open-sourced libraries like [Pyserini](https://github.com/castorini/pyserini).
107
+
108
+
109
+
110
+ ## Citation
111
+
112
+ If you find this repository useful, please consider giving a star :star: and citation
113
+
114
+ ```
115
+ @misc{li2023making,
116
+ title={Making Large Language Models A Better Foundation For Dense Retrieval},
117
+ author={Chaofan Li and Zheng Liu and Shitao Xiao and Yingxia Shao},
118
+ year={2023},
119
+ eprint={2312.15503},
120
+ archivePrefix={arXiv},
121
+ primaryClass={cs.CL}
122
+ }
123
+ ```