---
license: llama2
---

# RankLLaMA-7B-Document

[Fine-Tuning LLaMA for Multi-Stage Text Retrieval](https://arxiv.org/abs/2310.08319).
Xueguang Ma, Liang Wang, Nan Yang, Furu Wei, Jimmy Lin, arXiv 2023

This model is fine-tuned from LLaMA-2-7B with LoRA for document reranking. It accepts inputs of up to 4,096 tokens.

## Training Data
The model is fine-tuned on the training split of the [MS MARCO Document Ranking](https://microsoft.github.io/msmarco/Datasets) dataset for one epoch.
Please check our paper for details.

## Usage

Below is an example that computes the relevance score of a query-document pair:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel, PeftConfig

def get_model(peft_model_name):
    # Read the LoRA adapter config to find the base model it was trained on.
    config = PeftConfig.from_pretrained(peft_model_name)
    # Load the base LLaMA-2 model with a single-output ranking head.
    base_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path, num_labels=1)
    # Attach the LoRA adapter and merge its weights into the base model.
    model = PeftModel.from_pretrained(base_model, peft_model_name)
    model = model.merge_and_unload()
    model.eval()
    return model

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
model = get_model('castorini/rankllama-v1-7b-lora-doc')

# Define a query-document pair
query = "What is llama?"
url = "https://en.wikipedia.org/wiki/Llama"
title = "Llama"
document = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."

# Tokenize the query-document pair
inputs = tokenizer(f'query: {query}', f'document: {url} {title} {document}', return_tensors='pt')

# Run the model forward and read out the single relevance logit
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    score = logits[0][0]
    print(score)
```
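
In practice, a reranker scores many candidate documents for one query and sorts them. Below is a minimal sketch of this, reusing the `tokenizer` and `model` loaded above; the `candidates` list is illustrative, and the explicit `truncation=True, max_length=4096` setting is an assumption that matches the model's stated input limit rather than a documented requirement.

```python
# Rerank a list of candidate documents for one query (illustrative sketch;
# reuses `tokenizer` and `model` from the example above).
query = "What is llama?"
candidates = [
    ("https://en.wikipedia.org/wiki/Llama", "Llama",
     "The llama is a domesticated South American camelid, widely used as a "
     "meat and pack animal by Andean cultures since the pre-Columbian era."),
    ("https://en.wikipedia.org/wiki/Alpaca", "Alpaca",
     "The alpaca is a species of South American camelid mammal, often "
     "confused with the llama."),
]

scores = []
with torch.no_grad():
    for url, title, document in candidates:
        # Truncate long documents to stay within the 4,096-token input limit.
        inputs = tokenizer(f'query: {query}',
                           f'document: {url} {title} {document}',
                           return_tensors='pt',
                           truncation=True, max_length=4096)
        scores.append(model(**inputs).logits[0][0].item())

# Sort candidates by descending relevance score.
ranked = sorted(zip(candidates, scores), key=lambda x: x[1], reverse=True)
for (url, title, _), score in ranked:
    print(f'{score:.4f}\t{title}\t{url}')
```

The scores are unnormalized logits, so only their relative order within one query is meaningful.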

## Citation

If you find our paper or models helpful, please consider citing as follows:

```bibtex
@article{rankllama,
  title={Fine-Tuning LLaMA for Multi-Stage Text Retrieval},
  author={Xueguang Ma and Liang Wang and Nan Yang and Furu Wei and Jimmy Lin},
  year={2023},
  journal={arXiv:2310.08319},
}
```