---
library_name: peft
base_model: mistralai/Mistral-7B-Instruct-v0.2
license: apache-2.0
language:
- en
datasets:
- chtmp223/suri
---

# Suri-SFT
Suri-SFT is a version of mistralai/Mistral-7B-Instruct-v0.2 fine-tuned with supervised fine-tuning (SFT) using LoRA. See [our paper](https://arxiv.org/abs/2406.19371) for details on the method.

## 📒 Model Details

### Model Description

- **Language(s) (NLP):** English
- **License:** Apache-2.0
- **Finetuned from model:** [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)

### Model Sources

- **Repository:** [GitHub repository](https://github.com/chtmp223/suri) (contains code to reconstruct the Books3 subset)
- **Paper:** [Link](https://arxiv.org/abs/2406.19371)
- **Demo:** [Website](https://chtmp223.github.io/suri)

## ⚠️ Getting Started

Use the code in [this repository](https://github.com/chtmp223/suri) for training and inference. 


## 💻 Training Details

### Training Data

The model was trained on the [chtmp223/suri](https://huggingface.co/datasets/chtmp223/suri) dataset.

### Training Procedure

| **Configurations**               | **Values**   |
|----------------------------------|--------------|
| Hardware (Training and Inference)| 4xA100s      |
| Tracking                         | wandb        |
| lora_r                           | 16           |
| lora_alpha                       | 16           |
| lora_dropout                     | 0.05         |
| gradient_accumulation_steps      | 1            |
| gradient_checkpointing           | True         |
| learning_rate                    | 5.0e-5       |
| lr_scheduler_type                | cosine       |
| max_length                       | 15024        |
| max_completion_length            | 15000        |
| max_prompt_length                | 5000         |
| num_train_epochs                 | 2            |
| optim                            | adamw_torch  |
| per_device_train_batch_size      | 1            |
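
Two quantities follow directly from the table: the LoRA scaling factor and the effective batch size. The sketch below derives them using only the Python standard library. The `SftRun` class and its method names are hypothetical, and `num_devices = 4` assumes that all four A100s act as data-parallel workers, which the table does not state explicitly.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SftRun:
    """Values copied from the configuration table above."""

    lora_r: int = 16
    lora_alpha: int = 16
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 1
    # Assumption: all four A100s are used as data-parallel workers.
    num_devices: int = 4

    def lora_scaling(self) -> float:
        # Standard LoRA multiplies the low-rank update by alpha / r.
        return self.lora_alpha / self.lora_r

    def effective_batch_size(self) -> int:
        # Number of sequences contributing to each optimizer step.
        return (
            self.per_device_train_batch_size
            * self.gradient_accumulation_steps
            * self.num_devices
        )


run = SftRun()
print(run.lora_scaling())          # 1.0
print(run.effective_batch_size())  # 4
```

With `lora_alpha` equal to `lora_r`, the adapter update is applied unscaled, and under the stated assumption each optimizer step sees four sequences.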


#### Software

The training code is adapted from the [Alignment Handbook](https://github.com/huggingface/alignment-handbook) and [TRL](https://github.com/huggingface/trl).

## 🤗 Inference
```
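import os

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Disable tokenizer parallelism to avoid fork-related warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

model_name = "chtmp223/suri-sft"
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Adapter configuration; not needed below, but useful for inspection.
config = PeftConfig.from_pretrained(model_name)
# Load the base model, then attach the LoRA adapter weights to it.
base_model = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
model = PeftModel.from_pretrained(base_model, model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Placeholder instruction (an assumption for illustration); replace with your own.
user_prompt = "Write a short story about a lighthouse keeper."
prompt = [
  {
      "role": "user",
      "content": user_prompt,
  }
]
# Wrap the prompt in the Mistral instruction template as plain text.
input_context = tokenizer.apply_chat_template(
  prompt, add_generation_prompt=True, tokenize=False
)
# The template already contains the special tokens, so do not add them again.
input_ids = tokenizer.encode(
  input_context, return_tensors="pt", add_special_tokens=False
).to(model.device)
# max_length caps the total sequence length (prompt plus generated tokens).
output = model.generate(
  input_ids, max_length=10000, do_sample=True, use_cache=True
).cpu()

# The decoded text includes the prompt followed by the generated continuation.
print(tokenizer.decode(output[0]))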
```
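
The snippet above decodes the full sequence, so the printed text begins with the prompt, and `max_length` counts prompt tokens as well as generated ones. The sketch below uses plain Python lists and toy token IDs to show how to separate the continuation from the prompt and how many new tokens the cap leaves. `split_generation` is a hypothetical helper, not part of the Suri code.

```python
def split_generation(output_ids, prompt_len, max_length):
    """Separate newly generated token IDs from the echoed prompt.

    output_ids: full sequence returned by generate (prompt + continuation)
    prompt_len: number of prompt tokens, e.g. input_ids.shape[-1]
    max_length: the cap passed to generate; it counts prompt tokens too
    """
    new_tokens = list(output_ids[prompt_len:])
    # Upper bound on how many new tokens generate could emit under the cap.
    budget = max(max_length - prompt_len, 0)
    return new_tokens, budget


# Toy token IDs stand in for real tokenizer output.
toy_output = [1, 733, 16289, 28793, 9038, 3714, 264, 727, 2]
new_tokens, budget = split_generation(toy_output, prompt_len=4, max_length=10000)
print(new_tokens)  # [9038, 3714, 264, 727, 2]
print(budget)      # 9996
```

Applied to the snippet above, the same slicing would read `tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)`, which prints only the generated text.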


## 📜 Citation 

```
@misc{pham2024surimulticonstraintinstructionfollowing,
      title={Suri: Multi-constraint Instruction Following for Long-form Text Generation}, 
      author={Chau Minh Pham and Simeng Sun and Mohit Iyyer},
      year={2024},
      eprint={2406.19371},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2406.19371}, 
}
```

### ⚙️ Framework versions

- PEFT 0.11.1