Update README.md
README.md
Please use the following code for inference with the reward model.

```
import math
import os

import torch
from huggingface_hub import snapshot_download
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer


## Define the reward model class

class GPTRewardModel(nn.Module):
    def __init__(self, model_path):
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained(model_path)
        self.config = model.config
        self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
        self.model = model
        self.transformer = model.model
        # Linear value head that maps the final hidden state to a scalar reward.
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.tokenizer.pad_token = self.tokenizer.unk_token
        self.PAD_ID = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]

    def get_device(self):
        return self.model.device

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
    ):
        """
        input_ids, attention_mask: torch.Size([bs, seq_len])
        return: scores: List[bs]
        """
        bs = input_ids.shape[0]
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = transformer_outputs[0]
        scores = []
        rewards = self.v_head(hidden_states).squeeze(-1)
        for i in range(bs):
            # Read the reward at the last token before padding
            # (or at the final token if the sequence has no padding).
            c_inds = (input_ids[i] == self.PAD_ID).nonzero()
            c_ind = c_inds[0].item() if len(c_inds) > 0 else input_ids.shape[1]
            scores.append(rewards[i, c_ind - 1])
        return scores


## Load the model and tokenizer

reward_model = GPTRewardModel("meta-llama/Llama-2-7b-chat-hf")
reward_tokenizer = reward_model.tokenizer
reward_tokenizer.truncation_side = "left"

# Download the Starling-RM-7B-alpha checkpoint and load its weights.
directory = snapshot_download("berkeley-nest/Starling-RM-7B-alpha")
for fpath in os.listdir(directory):
    if fpath.endswith(".pt") or fpath.endswith("model.bin"):
        checkpoint = os.path.join(directory, fpath)
        break

reward_model.load_state_dict(torch.load(checkpoint), strict=False)
reward_model.eval().requires_grad_(False)

reward_device = reward_model.get_device()
reward_batch_size = 64  # micro-batch size for reward inference; adjust to fit memory


## Define the reward function

def get_reward(samples):
    """samples: List[str]"""
    input_ids = []
    attention_masks = []
    encodings_dict = reward_tokenizer(
        samples,
        truncation=True,
        max_length=2048,
        padding="max_length",
        return_tensors="pt",
    ).to(reward_device)
    input_ids = encodings_dict["input_ids"]
    attention_masks = encodings_dict["attention_mask"]
    mbs = reward_batch_size
    out = []
    for i in range(math.ceil(len(samples) / mbs)):
        rewards = reward_model(
            input_ids=input_ids[i * mbs : (i + 1) * mbs],
            attention_mask=attention_masks[i * mbs : (i + 1) * mbs],
        )
        out.extend(rewards)
    return torch.hstack(out)


## Inference over test prompts with llama2 chat template

test_sample = ["<s>[INST] Hello? </s> [/INST] Hi, how can I help you?</s>"]
reward_for_test_sample = get_reward(test_sample)
print(reward_for_test_sample)
```
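Beyond the single test prompt above, the same `get_reward` helper can be used to rank several candidate responses to one prompt. The sketch below is not part of the original snippet: the prompt and response strings are made up for illustration, and they simply reuse the llama2 chat template shown in `test_sample`.

```
# Hypothetical example: score two candidate responses to the same prompt
# and pick the one the reward model prefers. Strings are illustrative only.
prompt = "What is the capital of France?"
candidates = [
    "The capital of France is Paris.",
    "I think it might be Lyon, but I am not sure.",
]
samples = [f"<s>[INST] {prompt} </s> [/INST] {c}</s>" for c in candidates]

rewards = get_reward(samples)  # 1-D tensor, one score per sample
best_response = candidates[rewards.argmax().item()]
print(rewards, best_response)
```

Higher scores indicate responses the reward model judges as better, so the argmax selects the preferred candidate.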