import logging
from collections import defaultdict

import numpy as np
import tiktoken

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def check_format_errors(train_dataset, user_role, model_role):
    """
    Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
    """
    # Format error checks
    format_errors = defaultdict(int)

    for ex in train_dataset:
        if not isinstance(ex, dict):
            format_errors["data_type"] += 1
            continue

        messages = ex.get("messages", None)
        if not messages:
            format_errors["missing_messages_list"] += 1
            continue

        for message in messages:
            if "role" not in message or "content" not in message:
                format_errors["message_missing_key"] += 1

            if any(
                k not in ("role", "content", "name", "function_call", "weight")
                for k in message
            ):
                format_errors["message_unrecognized_key"] += 1

            if message.get("role", None) not in ["system", user_role, model_role]:
                format_errors["unrecognized_role"] += 1

            content = message.get("content", None)
            function_call = message.get("function_call", None)

            # A message needs either text content or a function call; if
            # content is present it must be a string.
            if (not content and not function_call) or (
                content is not None and not isinstance(content, str)
            ):
                format_errors["missing_content"] += 1

        if not any(message.get("role", None) == model_role for message in messages):
            format_errors["example_missing_assistant_message"] += 1

    if format_errors:
        logger.warning("Found errors:")
        for k, v in format_errors.items():
            logger.warning(f"{k}: {v}")
    else:
        logger.info("No errors found")

    return dict(format_errors)  # empty dict when no errors were found


def get_distributions(train_dataset, user_role, model_role):
    """
    Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep

    Gets the distributions of the number of messages per example, the total number of tokens per example, and the number of assistant tokens per example.
    """
    encoding = tiktoken.get_encoding("cl100k_base")

    # not exact!
    # simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Only string fields count toward the token total; non-string
                # fields (e.g. an integer "weight") would make encode() raise.
                if isinstance(value, str):
                    num_tokens += len(encoding.encode(value))
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
        return num_tokens

    def num_assistant_tokens_from_messages(messages):
        num_tokens = 0
        for message in messages:
            if message["role"] == model_role:
                num_tokens += len(encoding.encode(message["content"]))
        return num_tokens

    n_missing_system = 0
    n_missing_user = 0
    n_messages = []
    convo_lens = []
    assistant_message_lens = []

    for ex in train_dataset:
        messages = ex["messages"]
        if not any(message["role"] == "system" for message in messages):
            n_missing_system += 1
        if not any(message["role"] == user_role for message in messages):
            n_missing_user += 1
        n_messages.append(len(messages))
        convo_lens.append(num_tokens_from_messages(messages))
        assistant_message_lens.append(num_assistant_tokens_from_messages(messages))

    return {
        "n_missing_system": n_missing_system,
        "n_missing_user": n_missing_user,
        "n_messages": n_messages,
        "convo_lens": convo_lens,
        "assistant_message_lens": assistant_message_lens,
    }


def check_token_counts(train_dataset, user_role, model_role):
    """
    Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
    """

    def print_distribution(values, name):
        logger.info(f"\n#### Distribution of {name}:")
        logger.info(f"min / max: {min(values)}, {max(values)}")
        logger.info(f"mean / median: {np.mean(values)}, {np.median(values)}")
        logger.info(f"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}")

    # Warnings and tokens counts
    distributions = get_distributions(
        train_dataset, user_role=user_role, model_role=model_role
    )
    n_missing_system = distributions["n_missing_system"]
    n_missing_user = distributions["n_missing_user"]
    n_messages = distributions["n_messages"]
    convo_lens = distributions["convo_lens"]
    assistant_message_lens = distributions["assistant_message_lens"]

    logger.info("Num examples missing system message:", n_missing_system)
    logger.info("Num examples missing user message:", n_missing_user)
    print_distribution(n_messages, "num_messages_per_example")
    print_distribution(convo_lens, "num_total_tokens_per_example")
    print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
    n_too_long = sum(length > 4096 for length in convo_lens)
    logger.info(
        f"\n{n_too_long} examples may be over the 4096 token limit; "
        "they will be truncated during fine-tuning"
    )



def estimate_cost(train_dataset, user_role, model_role):
    """
    Extracted from: https://cookbook.openai.com/examples/chat_finetuning_data_prep
    """
    distributions = get_distributions(
        train_dataset, user_role=user_role, model_role=model_role
    )
    convo_lens = distributions["convo_lens"]

    # Pricing and default n_epochs estimate
    MAX_TOKENS_PER_EXAMPLE = 4096

    TARGET_EPOCHS = 3
    MIN_TARGET_EXAMPLES = 100
    MAX_TARGET_EXAMPLES = 25000
    MIN_DEFAULT_EPOCHS = 1
    MAX_DEFAULT_EPOCHS = 25

    n_epochs = TARGET_EPOCHS
    n_train_examples = len(train_dataset)
    try:
        if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
            n_epochs = min(MAX_DEFAULT_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
        elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
            n_epochs = max(MIN_DEFAULT_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)
    except ZeroDivisionError:
        # An empty dataset would divide by zero; fall back to the default.
        n_epochs = TARGET_EPOCHS

    n_billing_tokens_in_dataset = sum(
        min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens
    )

    return {
        "Estimated number of tokens in dataset": n_billing_tokens_in_dataset,
        f"Estimated number of tokens that will be billed (assuming {n_epochs} training epochs)": n_epochs
        * n_billing_tokens_in_dataset,
    }