banghua committed
Commit 73a3756
1 Parent(s): b5fc156

Update README.md

Files changed (1)
  1. README.md +93 -1
README.md CHANGED
@@ -43,8 +43,100 @@ For more detailed discussions, please check out our [blog post](starling.cs.berk
  Please use the following code for inference with the reward model.

  ```
+ import os
+ import math
+ import torch
+ from torch import nn
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from huggingface_hub import snapshot_download
+
  ## Define the reward model function class
- Test.
+
+ class GPTRewardModel(nn.Module):
+     def __init__(self, model_path):
+         super().__init__()
+         model = AutoModelForCausalLM.from_pretrained(model_path)
+         self.config = model.config
+         # Llama-style configs expose hidden_size; GPT-style configs expose n_embd.
+         self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
+         self.model = model
+         self.transformer = model.model
+         self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         self.tokenizer.pad_token = self.tokenizer.unk_token
+         self.PAD_ID = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]
+
+     def get_device(self):
+         return self.model.device
+
+     def forward(
+         self,
+         input_ids=None,
+         past_key_values=None,
+         attention_mask=None,
+         position_ids=None,
+     ):
+         """
+         input_ids, attention_mask: torch.Size([bs, seq_len])
+         return: scores: List[bs]
+         """
+         bs = input_ids.shape[0]
+         transformer_outputs = self.transformer(
+             input_ids,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+         )
+         hidden_states = transformer_outputs[0]
+         scores = []
+         rewards = self.v_head(hidden_states).squeeze(-1)
+         for i in range(bs):
+             # Reward is read at the last token before padding (or the final token if unpadded).
+             c_inds = (input_ids[i] == self.PAD_ID).nonzero()
+             c_ind = c_inds[0].item() if len(c_inds) > 0 else input_ids.shape[1]
+             scores.append(rewards[i, c_ind - 1])
+         return scores
+
+ ## Load the model and tokenizer
+
+ reward_model = GPTRewardModel("meta-llama/Llama-2-7b-chat-hf")
+ reward_tokenizer = reward_model.tokenizer
+ reward_tokenizer.truncation_side = "left"
+
+ directory = snapshot_download("berkeley-nest/Starling-RM-7B-alpha")
+ for fpath in os.listdir(directory):
+     if fpath.endswith(".pt") or fpath.endswith("model.bin"):
+         checkpoint = os.path.join(directory, fpath)
+         break
+
+ reward_model.load_state_dict(torch.load(checkpoint), strict=False)
+ reward_model.eval().requires_grad_(False)
+
+ # reward_device and reward_batch_size are used by get_reward below;
+ # adjust the batch size to fit your GPU memory.
+ reward_device = reward_model.get_device()
+ reward_batch_size = 64
+
+ ## Define the reward function
+
+ def get_reward(samples):
+     """samples: List[str]"""
+     encodings_dict = reward_tokenizer(
+         samples,
+         truncation=True,
+         max_length=2048,
+         padding="max_length",
+         return_tensors="pt",
+     ).to(reward_device)
+     input_ids = encodings_dict["input_ids"]
+     attention_masks = encodings_dict["attention_mask"]
+     mbs = reward_batch_size
+     out = []
+     for i in range(math.ceil(len(samples) / mbs)):
+         rewards = reward_model(input_ids=input_ids[i * mbs : (i + 1) * mbs], attention_mask=attention_masks[i * mbs : (i + 1) * mbs])
+         out.extend(rewards)
+     return torch.hstack(out)
+
+ ## Inference over test prompts with llama2 chat template
+
+ test_sample = ["<s>[INST] Hello? </s> [/INST] Hi, how can I help you?</s>"]
+ reward_for_test_sample = get_reward(test_sample)
+ print(reward_for_test_sample)
  ```
 
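For quick reference (this note is not part of the commit above): once the snippet has been run and `reward_model`, `reward_tokenizer`, and `get_reward` are defined, the same helper can rank several candidate responses to a single prompt. The sketch below is illustrative only; the `prompt` and `candidates` strings are invented, and it simply reuses the llama2 chat template shown in the test sample.

```
## Illustrative only: rank two hypothetical candidate replies with get_reward.
prompt = "How do I make a cup of tea?"
candidates = [
    "Boil water, pour it over a tea bag, steep for three minutes, then remove the bag.",
    "I don't know.",
]

# Wrap each (prompt, reply) pair in the llama2 chat template expected by the reward model.
samples = [f"<s>[INST] {prompt} </s> [/INST] {reply}</s>" for reply in candidates]

rewards = get_reward(samples)      # one scalar reward per sample (torch tensor)
best = int(torch.argmax(rewards))  # index of the highest-scoring candidate
print(rewards, candidates[best])
```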