```python
import torch
from torch import nn
from typing import List
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "CarperAI/stable-vicuna-13b-delta"
RM_PATH = "vicuna-v0-rm.pt"


class GPTRewardModel(nn.Module):
    def __init__(self):
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
        self.config = model.config
        # LLaMA-style configs expose hidden_size; GPT-style configs expose n_embd.
        self.config.n_embd = (
            self.config.hidden_size
            if hasattr(self.config, "hidden_size")
            else self.config.n_embd
        )
        self.transformer = model.model
        # Scalar value head mapping the final hidden state to a reward.
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        self.PAD_ID = self.tokenizer.pad_token_id

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        return_dict=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = transformer_outputs[0]
        # Per-token rewards: (batch, seq_len)
        rewards = self.v_head(hidden_states).squeeze(-1)

        # Take the reward at the last non-padding token of each sequence.
        end_scores = []
        bs = input_ids.shape[0]
        for i in range(bs):
            c_inds = (input_ids[i] == self.PAD_ID).nonzero()
            c_ind = c_inds[0].item() if len(c_inds) > 0 else input_ids.shape[1]
            end_scores.append(rewards[i, c_ind - 1])
        chosen_end_scores = torch.stack(end_scores)
        return {"end_scores": chosen_end_scores}


rw_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
rw_tokenizer.padding_side = "right"
rw_model = GPTRewardModel()
rw_model.load_state_dict(torch.load(RM_PATH)["module"])
rw_model.half()
rw_model.eval()
# rw_device is assumed to be defined by the surrounding script,
# e.g. rw_device = torch.device("cuda:0").
rw_model.to(rw_device)


def get_scores(samples: List[str]) -> torch.Tensor:
    """Score a list of strings in small batches with the reward model."""
    scores_list = []
    batch_size = 2
    for i in range(0, len(samples), batch_size):
        sub_samples = samples[i : i + batch_size]
        encodings_dict = rw_tokenizer(
            sub_samples,
            truncation=True,
            max_length=config.train.seq_length,  # config comes from the surrounding trlX script
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encodings_dict["input_ids"].to(rw_device)
        attn_masks = encodings_dict["attention_mask"].to(rw_device)
        with torch.no_grad():
            sub_scores = rw_model(input_ids=input_ids, attention_mask=attn_masks)
        scores_list.append(sub_scores["end_scores"])
    scores = torch.cat(scores_list, dim=0)
    return scores
```
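
For reference, here is a minimal usage sketch. It assumes `rw_device` and the `config` object are already defined by the surrounding script, and the sample strings themselves are purely illustrative:

```python
# Hypothetical prompt/response pairs; the exact formatting convention
# depends on how the reward model was trained.
samples = [
    "### Human: What is 2 + 2?### Assistant: 2 + 2 equals 4.",
    "### Human: What is 2 + 2?### Assistant: I'm not sure.",
]
scores = get_scores(samples)  # tensor of shape (len(samples),)
print(scores)
```

Because `get_scores` returns one scalar per sample (the value-head output at the last non-padding token), higher scores indicate responses the reward model prefers.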