File size: 4,787 Bytes
40676ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ea974c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaForSequenceClassification
from sklearn.tree import DecisionTreeClassifier
import os
import pickle
import json
from huggingface_hub import hf_hub_download
from typing import List, Dict, Union
import numpy as np

def convert_to_chat_format(prompt, response=None):
    """Convert a raw prompt (and optional response) into a list of chat messages.

    HelpSteer2 prompts may contain a multi-turn conversation whose turns are
    separated by the special token ``<extra_id_1>``. The text before the first
    separator becomes the opening user message. Every later turn is expected to
    be a role line, a newline, and then the content; a role of exactly
    ``"Assistant"`` maps to ``"assistant"`` and any other role maps to
    ``"user"``. A prompt without the separator becomes a single user message.

    Args:
        prompt: Prompt text, optionally containing ``<extra_id_1>`` separators.
        response: Optional assistant reply appended as the final message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dictionaries.
    """
    separator = "<extra_id_1>"
    if separator in prompt:
        first_turn, *later_turns = prompt.split(separator)
        conversation = [{"role": "user", "content": first_turn}]
        for turn in later_turns:
            # partition() always yields three parts, so a turn that has a role
            # line but no newline gets empty content instead of an IndexError.
            role, _, content = turn.partition("\n")
            conversation.append({
                "role": "assistant" if role == "Assistant" else "user",
                "content": content,
            })
    else:
        conversation = [{"role": "user", "content": prompt}]
    if response is not None:
        conversation.append({"role": "assistant", "content": response})
    return conversation

def process_conversation(conversation):
    """Strip trailing newline characters from every message's content.

    The message dictionaries are modified in place, and the same list object
    that was passed in is returned.
    """
    for turn in conversation:
        trimmed = turn["content"].rstrip('\n')
        turn["content"] = trimmed
    return conversation

class LlamaForDecisionTreeRewardModel(LlamaForSequenceClassification):
    """Llama sequence classifier whose attribute rewards are fed to a decision tree.

    The ``score`` linear head maps the final hidden state of a conversation to
    one reward per attribute. ``compare`` computes those rewards for two
    candidate responses and asks a ``DecisionTreeClassifier`` to predict a
    preference from their difference.
    """

    def __init__(self, config):
        super().__init__(config)
        # Assigned after super().__init__ so this head (with a bias term)
        # supersedes any `score` attribute the parent class may have created.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=True)
        # The decision tree is loaded lazily via load_decision_tree().
        self.tree = None
        # Default attribute names (from HelpSteer2); replaced by the labels in
        # config.json once load_decision_tree() is called.
        self.attributes = ['helpfulness', 'correctness', 'coherence', 'complexity', 'verbosity']
        print("Initialized LlamaForDecisionTreeRewardModel")

    def load_decision_tree(self, repo_id, filename="decision_tree.pkl"):
        """Download and load the decision tree and attribute names from the Hub.

        Args:
            repo_id: Hugging Face Hub repository that holds the pickled tree
                and ``config.json``.
            filename: Name of the pickled ``DecisionTreeClassifier`` file.

        Raises:
            AssertionError: If the unpickled object is not a
                ``DecisionTreeClassifier``.
        """
        tree_path = hf_hub_download(repo_id=repo_id, filename=filename)
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only call this with a repo_id you trust.
        with open(tree_path, "rb") as f:
            self.tree = pickle.load(f)
            assert isinstance(self.tree, DecisionTreeClassifier), f"The tree is not a DecisionTreeClassifier. It is a {type(self.tree)}"
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        # JSON is UTF-8 encoded; do not depend on the platform default.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        label2id_map = config["label2id"]
        # Attribute names ordered by their label id.
        self.attributes = [
            label
            for label, _ in sorted(label2id_map.items(), key=lambda item: item[1])
        ]

    @torch.no_grad()
    def compare(self, prompt: Union[str, List[Dict[str, str]]], response_1: str, response_2: str, tokenizer, device):
        """Score two responses to the same prompt and predict a preference.

        Args:
            prompt: Either a raw prompt string (converted with
                ``convert_to_chat_format``) or a list of chat messages whose
                last message has role ``"user"``. The caller's list and its
                message dictionaries are not modified.
            response_1: First candidate assistant response.
            response_2: Second candidate assistant response.
            tokenizer: Tokenizer providing ``apply_chat_template``.
            device: Device the tokenized inputs are moved to.

        Returns:
            A dictionary with:
              * ``"preference"``: the first prediction of the decision tree
                for ``rewards_2 - rewards_1``;
              * ``"rewards"``: the reward rows of response 1 followed by those
                of response 2, concatenated along the first axis;
              * ``"attributes"``: the attribute names in ``self.attributes``.

        Raises:
            AssertionError: If the tree is not loaded, the conversation is
                empty, or its last message is not from the user.
            ValueError: If ``prompt`` is neither a string nor a list.
        """
        assert self.tree is not None, "The decision tree is not loaded. Please call load_decision_tree(repo_id, filename) first."
        if isinstance(prompt, str):
            conversation = convert_to_chat_format(prompt)
        elif isinstance(prompt, list):
            # process_conversation() edits message dicts in place, so work on
            # shallow copies to keep the caller's prompt untouched.
            conversation = [dict(message) for message in prompt]
        else:
            raise ValueError(f"The prompt must be a string or a list of dictionaries, but got {type(prompt)}")
        assert isinstance(conversation, list), "The conversation must be a list of dictionaries"
        assert len(conversation) >= 1, "The conversation must have at least one message (as prompt)"
        assert conversation[-1]["role"] == "user", "The last message in the conversation must be from the user"
        conversation_1 = conversation + [{"role": "assistant", "content": response_1}]
        conversation_2 = conversation + [{"role": "assistant", "content": response_2}]
        conversation_1 = process_conversation(conversation_1)
        conversation_2 = process_conversation(conversation_2)

        conv_tokenized_1 = tokenizer.apply_chat_template(conversation_1, tokenize=True, return_tensors="pt").to(device)
        conv_tokenized_2 = tokenizer.apply_chat_template(conversation_2, tokenize=True, return_tensors="pt").to(device)
        # The last layer's hidden state at the final position serves as the
        # embedding of each conversation.
        embedding_1 = self.forward(conv_tokenized_1, output_hidden_states=True).hidden_states[-1][:, -1].float().cpu().numpy()
        embedding_2 = self.forward(conv_tokenized_2, output_hidden_states=True).hidden_states[-1][:, -1].float().cpu().numpy()
        # Apply the score head in NumPy on the float32 CPU copies.
        weight = self.score.weight.float().cpu().numpy()
        bias = self.score.bias.float().cpu().numpy()
        rewards_1 = embedding_1 @ weight.T + bias
        rewards_2 = embedding_2 @ weight.T + bias
        rewards_diff = rewards_2 - rewards_1
        return {
            "preference": self.tree.predict(rewards_diff)[0],
            "rewards": np.concatenate([rewards_1, rewards_2]),
            "attributes": self.attributes,
        }