---
language:
- en
license: apache-2.0
library_name: transformers
datasets:
- ms_marco
pipeline_tag: text2text-generation
widget:
- text: how to bake perfect cookie
inference_config:
  generation_config:
    max_length: 35
    num_beams: 1
    do_sample: true
    repetition_penalty: 1.8
tags:
- code
---

## Model Summary
This is a generative sequence-to-sequence model designed specifically for search query rewriting: given an input query, it generates reformulated queries. Performance is further boosted with a reinforcement learning framework built on a policy gradient algorithm, using reward functions that diversify the generated queries by paraphrasing keywords. The model can be combined with sparse retrieval methods, such as BM25-based retrieval, to improve document recall in search.
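As a rough sketch of that integration (not code from this model card), the example below unions BM25 results retrieved for the original query and for each reformulation. The `rank_bm25` package, the toy corpus, and the hard-coded reformulations are assumptions for illustration only; in practice the reformulations come from this model, as shown in the "How to use" section below.

```python
# Minimal sketch (not from this model card): widening recall by unioning BM25
# results for the original query and its reformulations.
# Requires `pip install rank_bm25`; corpus and reformulations are toy stand-ins.
from rank_bm25 import BM25Okapi

corpus = [
    "easy chocolate chip cookie recipe with soft centers",
    "tips for baking cookies evenly in a convection oven",
    "how long to chill cookie dough before baking",
    "troubleshooting flat or burnt cookies",
]
bm25 = BM25Okapi([doc.split() for doc in corpus])

original_query = "how to bake great cookie"
# Stand-in reformulations; in practice these are generated by this model.
reformulations = [
    "best way to bake cookies",
    "cookie baking tips for beginners",
]

# Retrieve for the original query and each reformulation, then take the union.
retrieved = set()
for query in [original_query] + reformulations:
    retrieved.update(bm25.get_top_n(query.split(), corpus, n=2))

print(retrieved)
```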

### Intended use cases
- Query rewriting for search (web, e-commerce)
- Virtual assistants and chatbots
- Information retrieval

### Model Description

#### Training Procedure

1. The training process begins by initializing the sequence-to-sequence model with Google's [T5-base model](https://huggingface.co/google-t5/t5-base).
2. The model first undergoes supervised training on the [MS-MARCO query pairs dataset](https://github.com/Narabzad/msmarco-query-reformulation/tree/main/datasets/queries).
3. It is then fine-tuned with a reinforcement learning (RL) framework to enhance its ability to generate queries that are both diverse and relevant.
4. The RL stage uses a policy gradient approach: for a given input query, a set of trajectories (reformulated queries) is sampled from the model, a reward is computed for each, and the policy gradient algorithm is applied to update the model (see the sketch after this list).
5. Rewards are computed heuristically to enhance the model's paraphrasing capability; they can be replaced with other domain- or goal-specific reward functions as needed.
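The snippet below is a minimal, illustrative REINFORCE-style update for steps 4 and 5; it is not the repository's training code. The token-overlap `reward_fn` is a placeholder assumption standing in for the heuristic paraphrasing rewards described above, and the hyperparameters are arbitrary.

```python
# Illustrative REINFORCE-style policy gradient step (a sketch, not the repo's code).
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-base")
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

def reward_fn(source: str, candidate: str) -> float:
    # Placeholder paraphrase reward: favor candidates that share fewer surface
    # tokens with the source query (the real heuristics live in the repository).
    src, cand = set(source.lower().split()), set(candidate.lower().split())
    return 1.0 - len(src & cand) / max(len(cand), 1)

query = "how to bake great cookie"
enc = tokenizer(query, return_tensors="pt")

# Step 4: sample a set of trajectories (reformulated queries) from the current policy.
num_samples = 4
samples = model.generate(**enc, do_sample=True, max_length=35,
                         num_return_sequences=num_samples)

optimizer.zero_grad()
for seq in samples:
    candidate = tokenizer.decode(seq, skip_special_tokens=True)
    reward = reward_fn(query, candidate)  # step 5: heuristic (here: placeholder) reward

    # Log-probability of the sampled trajectory under the current policy.
    labels = seq[1:].unsqueeze(0).clone()             # drop the decoder start token
    labels[labels == tokenizer.pad_token_id] = -100   # ignore padding
    out = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask, labels=labels)
    log_prob = -out.loss * (labels != -100).sum()     # loss is mean NLL per token

    # Policy gradient: ascend on reward * log pi(trajectory | query).
    (-reward * log_prob / num_samples).backward()
optimizer.step()
```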

Refer to [this repository](https://github.com/PraveenSH/RL-Query-Reformulation) for more details.


### Model Sources


- **Repository:** https://github.com/PraveenSH/RL-Query-Reformulation



### How to use
To get the most out of this model, use sampling with a repetition penalty to generate diverse reformulations. Sample code is provided below.
```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

MODEL_ID = "prhegde/t5-query-reformulation-RL"

# Load the tokenizer and model, then put the model in inference mode.
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model.eval()

input_sequence = "how to bake great cookie"
input_ids = tokenizer(input_sequence, return_tensors="pt").input_ids
print(f'Input: {input_sequence}')

# Draw several samples; with do_sample=True each iteration typically
# yields a different reformulation.
nsent = 4
with torch.no_grad():
    for i in range(nsent):
        output = model.generate(input_ids, max_length=35, num_beams=1, do_sample=True, repetition_penalty=1.8)
        target_sequence = tokenizer.decode(output[0], skip_special_tokens=True)
        print(f'Target: {target_sequence}')
```
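The generation arguments above (`max_length=35`, `num_beams=1`, `do_sample=True`, `repetition_penalty=1.8`) mirror the `inference_config` in the model card metadata at the top of this page.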