---
language:
- en
license: apache-2.0
library_name: transformers
datasets:
- ms_marco
pipeline_tag: text2text-generation
widget:
- text: how to bake perfect cookie
inference_config:
  generation_config:
    max_length: 35
    num_beams: 1
    do_sample: true
    repetition_penalty: 1.8
tags:
- code
---

## Model Summary

This is a generative model designed for search query rewriting. It uses a sequence-to-sequence architecture to generate reformulated queries and is fine-tuned with a reinforcement learning (RL) framework built on a policy gradient algorithm. The model is trained with reward functions that encourage diverse reformulations which paraphrase the keywords of the input query. It can be combined with sparse retrieval methods, such as BM25-based retrieval, to improve document recall in search.

### Intended use cases

- Query rewriting for search (web, e-commerce)
- Virtual assistants and chatbots
- Information retrieval

### Model Description

#### Training Procedure

1. The training process begins by initializing the sequence-to-sequence model with Google's [T5-base model](https://huggingface.co/google-t5/t5-base).

2. The model first undergoes supervised training on the [MS-MARCO query pairs dataset](https://github.com/Narabzad/msmarco-query-reformulation/tree/main/datasets/queries).

3. It is then fine-tuned with a reinforcement learning (RL) framework to enhance its ability to generate queries that are both diverse and relevant.

4. Fine-tuning uses a policy gradient approach: for a given input query, a set of trajectories (reformulated queries) is sampled from the model, a reward is computed for each, and a policy gradient update is applied to the model (see the sketch after this list).

5. Rewards are computed heuristically to strengthen the model's paraphrasing capability; they can be replaced with other domain-specific or goal-specific reward functions as needed.

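For illustration, a minimal REINFORCE-style sketch of one policy-gradient update is shown below. The `compute_reward` heuristic here is hypothetical (it rewards tokens not copied from the input); the actual reward functions and training loop are in the linked repository.

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

MODEL_ID = "prhegde/t5-query-reformulation-RL"
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

def compute_reward(source: str, candidate: str) -> float:
    # Hypothetical heuristic: fraction of candidate tokens not copied
    # from the source query, which favors paraphrased keywords.
    src_tokens = set(source.lower().split())
    cand_tokens = candidate.lower().split()
    if not cand_tokens:
        return 0.0
    return sum(t not in src_tokens for t in cand_tokens) / len(cand_tokens)

query = "how to bake great cookie"
input_ids = tokenizer(query, return_tensors="pt").input_ids

# Sample a trajectory (a reformulated query) from the current policy.
sample = model.generate(input_ids, do_sample=True, max_length=35)
candidate = tokenizer.decode(sample[0], skip_special_tokens=True)
reward = compute_reward(query, candidate)

# Re-score the sampled sequence; drop the leading decoder-start token
# so the generated tokens act as labels.
labels = sample[:, 1:].contiguous()
nll = model(input_ids=input_ids, labels=labels).loss

# REINFORCE surrogate: weight the negative log-likelihood by the reward.
# In practice a baseline is subtracted from the reward to reduce variance.
(reward * nll).backward()
optimizer.step()
optimizer.zero_grad()
```

The supervised stage in step 2 corresponds to the same cross-entropy loss computed against reference reformulations from MS-MARCO, without the reward weighting.
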
Refer to [this repository](https://github.com/PraveenSH/RL-Query-Reformulation) for more details.

### Model Sources

- **Repository:** https://github.com/PraveenSH/RL-Query-Reformulation

### How to use

To get the most out of this model, use sampling with a repetition penalty so that repeated calls produce diverse reformulations. Sample code is given below.

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

MODEL_ID = "prhegde/t5-query-reformulation-RL"

# Load the tokenizer and model, and switch to inference mode.
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model.eval()

input_sequence = "how to bake great cookie"
input_ids = tokenizer(input_sequence, return_tensors="pt").input_ids
print(f'Input: {input_sequence}')

# Draw several samples; do_sample=True with a repetition penalty
# produces diverse reformulations of the same input query.
nsent = 4
with torch.no_grad():
    for i in range(nsent):
        output = model.generate(input_ids, max_length=35, num_beams=1, do_sample=True, repetition_penalty=1.8)
        target_sequence = tokenizer.decode(output[0], skip_special_tokens=True)
        print(f'Target: {target_sequence}')
```
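
As noted in the summary, the generated rewrites can be fed to a sparse retriever to improve recall. The sketch below assumes the third-party `rank_bm25` package and a toy corpus; the hard-coded `rewrites` list stands in for outputs of the snippet above.

```python
from rank_bm25 import BM25Okapi

# Toy corpus; in practice this is the document collection to search.
corpus = [
    "easy chocolate chip cookie recipe",
    "how long to bake cookies at 350 degrees",
    "tips for soft and chewy homemade cookies",
    "an overview of search query rewriting",
]
bm25 = BM25Okapi([doc.lower().split() for doc in corpus])

# `rewrites` stands in for reformulations generated by the model above.
original = "how to bake great cookie"
rewrites = ["best way to bake cookies", "tips for baking perfect cookies"]

# Union the top documents retrieved for the original query and each
# rewrite; the combined set typically has higher recall than the
# original query alone.
retrieved = set()
for q in [original] + rewrites:
    retrieved.update(bm25.get_top_n(q.lower().split(), corpus, n=2))

for doc in retrieved:
    print(doc)
```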