---
library_name: peft
base_model: mistralai/Mistral-7B-Instruct-v0.2
license: apache-2.0
language:
- en
---

# Suri-I-ORPO
Suri-I-ORPO is a fine-tuned version of mistralai/Mistral-7B-Instruct-v0.2, trained with instructional odds ratio preference optimization (I-ORPO). See [our paper](TODO) for details on the method.

## 📒 Model Details

### Model Description

- **Language(s) (NLP):** English
- **License:** Apache-2.0
- **Finetuned from model:** [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)

### Model Sources

- **Repository:** [GitHub repository](https://github.com/chtmp223/suri), which contains code to reconstruct the Books3 subset.
- **Paper:** TODO
- **Demo:** [Website](https://chtmp223.github.io/suri)

## ⚠️ Getting Started

Use the code in [this repository](https://github.com/chtmp223/suri) for training and inference. 


## 💻 Training Details

### Training Data

The model was trained on the [chtmp223/suri](https://huggingface.co/datasets/chtmp223/suri) dataset.

### Training Procedure

| **Configurations**               | **Values**   |
|----------------------------------|--------------|
| Hardware (Training and Inference)| 4xA100s      |
| Tracking                         | wandb        |
| lora_r                           | 16           |
| lora_alpha                       | 16           |
| lora_dropout                     | 0.05         |
| beta                             | 0.4          |
| gradient_accumulation_steps      | 1            |
| gradient_checkpointing           | True         |
| learning_rate                    | 5.0e-5       |
| lr_scheduler_type                | cosine       |
| max_length                       | 15024        |
| max_completion_length            | 15000        |
| max_prompt_length                | 5000         |
| num_train_epochs                 | 2            |
| optim                            | adamw_torch  |
| per_device_train_batch_size      | 1            |
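
Two derived quantities follow from the table and may help when adapting the setup to different hardware. The sketch below is illustrative only: it assumes standard LoRA scaling (`lora_alpha / lora_r`) and that all four GPUs act as data-parallel workers, neither of which is stated in this card. The variable names are hypothetical.

```python
# Values copied from the configuration table above.
config = {
    "lora_r": 16,
    "lora_alpha": 16,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 1,
}

# Assumption: all 4 A100s are used as data-parallel workers.
num_gpus = 4

# Standard LoRA scales the low-rank update by alpha / r.
lora_scaling = config["lora_alpha"] / config["lora_r"]

# Sequences contributing to each optimizer step under the assumption above.
effective_batch_size = (
    config["per_device_train_batch_size"]
    * config["gradient_accumulation_steps"]
    * num_gpus
)

print(f"LoRA scaling factor: {lora_scaling}")
print(f"Effective batch size: {effective_batch_size}")
```

On fewer GPUs, raising `gradient_accumulation_steps` is one way to keep the effective batch size unchanged.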

#### Software

Training code is adapted from the [Alignment Handbook](https://github.com/huggingface/alignment-handbook) and [TRL](https://github.com/huggingface/trl).

## 🤗 Inference

```python
import os

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Disable tokenizers parallelism to avoid fork-related warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

model_name = "chtmp223/suri-i-orpo"
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Load the base model, then attach the adapter weights on top of it.
base_model = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
model = PeftModel.from_pretrained(base_model, model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Placeholder instruction; replace it with your own prompt.
user_prompt = "Write a short story about a lighthouse keeper."
prompt = [
    {
        "role": "user",
        "content": user_prompt,
    }
]

# Render the chat template as text, then tokenize it without adding
# special tokens a second time.
input_context = tokenizer.apply_chat_template(
    prompt, add_generation_prompt=True, tokenize=False
)
input_ids = tokenizer.encode(
    input_context, return_tensors="pt", add_special_tokens=False
).to(model.device)

# max_length counts the prompt tokens plus the generated tokens.
output = model.generate(
    input_ids, max_length=10000, do_sample=True, use_cache=True
).cpu()

print(tokenizer.decode(output[0]))
```
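
For decoder-only models, `generate` returns the prompt tokens followed by the continuation, and `max_length` bounds their combined length. The sketch below illustrates both points with plain Python lists standing in for tensors. The helpers `new_token_budget` and `strip_prompt` are hypothetical names introduced here, not part of the repository, and the token ids are made up.

```python
def new_token_budget(prompt_length: int, max_length: int) -> int:
    """Number of tokens that can still be generated under max_length."""
    return max(0, max_length - prompt_length)


def strip_prompt(output_ids, prompt_length: int):
    """Drop the echoed prompt and keep only the generated tokens."""
    return output_ids[prompt_length:]


# Made-up token ids, for illustration only.
prompt_ids = [1, 733, 16289]
output_ids = prompt_ids + [5, 6, 7]

budget = new_token_budget(len(prompt_ids), max_length=10000)
continuation = strip_prompt(output_ids, len(prompt_ids))

print(budget)
print(continuation)
```

Applied to the inference snippet, the same slicing would look like `tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)`, which prints only the newly generated text.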


## 📜 Citation 

```
TODO
```

### ⚙️ Framework versions

- PEFT 0.11.1