---
pipeline_tag: sentence-similarity
tags:
- sentence-transformers
- feature-extraction
- sentence-similarity
license: mit
---

For more details, please refer to our GitHub repo: https://github.com/FlagOpen/FlagEmbedding

# LLaRA ([paper](https://arxiv.org/pdf/2312.15503))

In this project, we introduce LLaRA, which adapts a large language model (LLM) for dense retrieval through two pretext tasks:
- EBAE: Embedding-Based Auto-Encoding.
- EBAR: Embedding-Based Auto-Regression.
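For intuition, the two pretext tasks can be sketched as one loss with different targets. This is a minimal sketch under stated assumptions, not the training code: it assumes a bag-of-words style objective in which the pooled text embedding (in the usage example, the mean of the final hidden states at the eight placeholder tokens) is projected onto the vocabulary through an output matrix, and the negative log-likelihood of the target tokens is averaged. Under that assumption, EBAE targets the tokens of the input text itself, while EBAR targets the tokens of the text that follows it. The function names, toy vocabulary, and weights are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def bag_of_tokens_loss(embedding: Sequence[float],
                       output_proj: Sequence[Sequence[float]],
                       target_ids: Sequence[int]) -> float:
    """Hypothetical EBAE/EBAR-style loss: project the text embedding onto the
    vocabulary and average the negative log-likelihood of every target token.

    output_proj[v] is the (assumed) output-embedding row for vocabulary id v.
    """
    logits = [sum(w * x for w, x in zip(row, embedding)) for row in output_proj]
    log_probs = log_softmax(logits)
    return -sum(log_probs[t] for t in target_ids) / len(target_ids)


# Toy setup: a 4-token vocabulary with 2-dimensional embeddings (illustrative only).
W = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
e = [2.0, 0.5]  # stand-in for the pooled placeholder-token embedding

input_tokens = [0, 1]  # EBAE target: tokens of the input text itself
next_tokens = [2, 3]   # EBAR target: tokens of the text that follows

print("EBAE-style loss:", round(bag_of_tokens_loss(e, W, input_tokens), 4))
print("EBAR-style loss:", round(bag_of_tokens_loss(e, W, next_tokens), 4))
```

In this toy setup the embedding points toward tokens 0 and 1, so the EBAE-style loss is lower than the EBAR-style loss; training on both targets pushes a single embedding to encode the text and anticipate what follows it.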


## Usage

```python
import torch
from transformers import AutoModel, AutoTokenizer


def get_query_inputs(queries, tokenizer, max_length=512):
    # Query prompt: "<query>", predict the following passage within eight words: <s9>...<s16>
    prefix = '"'
    suffix = '", predict the following passage within eight words: <s9><s10><s11><s12><s13><s14><s15><s16>'
    prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']  # keeps the BOS token
    suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]  # drops the BOS token
    queries_inputs = []
    for query in queries:
        inputs = tokenizer(query,
                           return_tensors=None,
                           max_length=max_length,
                           truncation=True,
                           add_special_tokens=False)
        inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
        inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        queries_inputs.append(inputs)
    return tokenizer.pad(
        queries_inputs,
        padding=True,
        max_length=max_length,
        pad_to_multiple_of=8,
        return_tensors='pt',
    )


def get_passage_inputs(passages, tokenizer, max_length=512):
    # Passage prompt: "<passage>", summarize the above passage within eight words: <s1>...<s8>
    prefix = '"'
    suffix = '", summarize the above passage within eight words: <s1><s2><s3><s4><s5><s6><s7><s8>'
    prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']  # keeps the BOS token
    suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]  # drops the BOS token
    passages_inputs = []
    for passage in passages:
        inputs = tokenizer(passage,
                           return_tensors=None,
                           max_length=max_length,
                           truncation=True,
                           add_special_tokens=False)
        inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
        inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        passages_inputs.append(inputs)
    return tokenizer.pad(
        passages_inputs,
        padding=True,
        max_length=max_length,
        pad_to_multiple_of=8,
        return_tensors='pt',
    )


# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('BAAI/LLARA-beir')
model = AutoModel.from_pretrained('BAAI/LLARA-beir')
# Embeddings are read from the last 8 positions, so padding must be added on the left
tokenizer.padding_side = 'left'

# Define query and passage inputs
query = "What is llama?"
passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."
query_input = get_query_inputs([query], tokenizer)
passage_input = get_passage_inputs([passage], tokenizer)

with torch.no_grad():
    # Query embedding: mean of the final-layer states at <s9>...<s16>, L2-normalized
    query_outputs = model(**query_input, return_dict=True, output_hidden_states=True)
    query_embedding = query_outputs.hidden_states[-1][:, -8:, :]
    query_embedding = torch.mean(query_embedding, dim=1)
    query_embedding = torch.nn.functional.normalize(query_embedding, dim=-1)

    # Passage embedding: mean of the final-layer states at <s1>...<s8>, L2-normalized
    passage_outputs = model(**passage_input, return_dict=True, output_hidden_states=True)
    passage_embeddings = passage_outputs.hidden_states[-1][:, -8:, :]
    passage_embeddings = torch.mean(passage_embeddings, dim=1)
    passage_embeddings = torch.nn.functional.normalize(passage_embeddings, dim=-1)

    # Similarity score: dot product of normalized embeddings (cosine similarity)
    score = query_embedding @ passage_embeddings.T
    print(score)
```
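The embedding step above mean-pools the final-layer hidden states at the last eight positions (the placeholder tokens), L2-normalizes the result, and scores query-passage pairs by dot product, i.e. cosine similarity. The dependency-free sketch below reproduces only that arithmetic on made-up hidden states, so the pooling and ranking logic can be followed without downloading the model. The helpers `pool_last_k`, `l2_normalize`, and `cosine_scores` and all numbers are hypothetical, and the sketch assumes left padding so that the placeholder tokens really occupy the last eight positions.

```python
import math
from typing import List, Sequence

Vector = List[float]


def pool_last_k(hidden_states: Sequence[Sequence[float]], k: int = 8) -> Vector:
    """Mean of the last k position vectors (mirrors hidden_states[-1][:, -k:, :] then mean over dim 1)."""
    tail = hidden_states[-k:]
    dim = len(tail[0])
    return [sum(vec[d] for vec in tail) / len(tail) for d in range(dim)]


def l2_normalize(v: Sequence[float]) -> Vector:
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_scores(query_vec: Sequence[float],
                  passage_vecs: Sequence[Sequence[float]]) -> Vector:
    """Dot products of one normalized query against normalized passages."""
    return [sum(q * p for q, p in zip(query_vec, pv)) for pv in passage_vecs]


# Toy final-layer states for one sequence each: 10 positions, hidden size 3.
# Only the last 8 positions (standing in for the placeholder tokens) are pooled.
query_states = [[0.0, 0.0, 0.0]] * 2 + [[1.0, 0.2, 0.0]] * 8
passage_a = [[5.0, 5.0, 5.0]] * 2 + [[0.9, 0.3, 0.1]] * 8   # similar direction
passage_b = [[5.0, 5.0, 5.0]] * 2 + [[-0.2, 0.1, 1.0]] * 8  # different direction

q = l2_normalize(pool_last_k(query_states))
passages = [l2_normalize(pool_last_k(s)) for s in (passage_a, passage_b)]
scores = cosine_scores(q, passages)
ranking = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
print("scores:", [round(s, 4) for s in scores], "ranking:", ranking)
```

The first two positions of each toy sequence differ sharply from the rest, yet they do not affect the scores, which is why the real example depends on the placeholder tokens sitting at the end of every padded sequence.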


## Acknowledgement

We thank the authors of open-source datasets such as MS MARCO and BEIR, and of open-source libraries such as [Pyserini](https://github.com/castorini/pyserini).



## Citation

If you find this repository useful, please consider giving it a star :star: and citing our work:

```bibtex
@misc{li2023making,
      title={Making Large Language Models A Better Foundation For Dense Retrieval}, 
      author={Chaofan Li and Zheng Liu and Shitao Xiao and Yingxia Shao},
      year={2023},
      eprint={2312.15503},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```