# File size: 8,990 Bytes
# 4bdb245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import argparse
import os
import re
from dataclasses import dataclass, field
from typing import List, Optional

# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp


@dataclass
class GptParams:
    """Generation and sampling settings, modelled on llama.cpp's ``gpt_params``.

    Instances are normally produced by ``gpt_params_parse`` from command-line
    flags; every field can also be set directly as a keyword argument.
    """

    seed: int = -1  # RNG seed; per the --seed help text, <= 0 means "use a random seed"
    n_threads: int = min(4, os.cpu_count() or 1)  # capped at 4, at least 1 if cpu_count() is unknown
    n_predict: int = 128  # tokens to generate; -1 = no limit (per --n_predict help)
    n_parts: int = -1
    n_ctx: int = 512
    n_batch: int = 8
    n_keep: int = 0

    ignore_eos: bool = False
    # token id -> bias, filled from repeated --logit-bias TOKEN_ID(+/-)BIAS flags
    logit_bias: dict[int, float] = field(default_factory=dict)
    top_k: int = 40
    top_p: float = 0.95
    tfs_z: float = 1.00  # 1.0 disables tail-free sampling
    typical_p: float = 1.00
    temp: float = 0.80
    repeat_penalty: float = 1.10
    repeat_last_n: int = 64
    frequency_penalty: float = 0.0  # 0.0 = disabled
    presence_penalty: float = 0.0  # 0.0 = disabled
    mirostat: int = 0
    mirostat_tau: float = 5.0
    mirostat_eta: float = 0.1

    model: str = "./models/llama-7B/ggml-model.bin"
    prompt: str = ""
    path_session: str = ""  # empty string = no session file
    input_prefix: str = " "
    input_suffix: str = ""
    antiprompt: List[str] = field(default_factory=list)  # reverse prompts (-r)

    lora_adapter: str = ""
    lora_base: str = ""

    memory_f16: bool = True  # False selects f32 key/value memory (--memory_f32)
    random_prompt: bool = False
    use_color: bool = False
    interactive: bool = False

    embedding: bool = False
    interactive_start: bool = False

    instruct: bool = False
    penalize_nl: bool = True
    perplexity: bool = False
    use_mmap: bool = True
    use_mlock: bool = False
    mem_test: bool = False
    verbose_prompt: bool = False

    file: Optional[str] = None  # optional path to a file holding the initial prompt

    # If chat ended prematurely, append this to the conversation to fix it.
    # Set to "\nUser:" etc.
    # This is an alternative to input_prefix, which always adds its text and so may duplicate "User:".
    fix_prefix: str = ""
    # Plain bool; a trailing comma here would silently make the default the tuple (True,).
    input_echo: bool = True

    # Default instructions for Alpaca
    # switch to "Human" and "Assistant" for Vicuna.
    # TODO: TBD how they are gonna handle this upstream
    instruct_inp_prefix: str="\n\n### Instruction:\n\n"
    instruct_inp_suffix: str="\n\n### Response:\n\n"


def gpt_params_parse(argv = None):
    """Build a ``GptParams`` from llama.cpp-style command-line flags.

    Args:
        argv: Argument strings to parse, excluding the program name. ``None``
            (the default) makes ``argparse`` read ``sys.argv[1:]``.

    Returns:
        A ``GptParams`` populated from the parsed flags. Each flag's ``dest``
        names the ``GptParams`` field it fills, except ``logit_bias_str``,
        which is converted into ``GptParams.logit_bias``.

    Raises:
        SystemExit: raised by argparse for unknown flags, badly typed values,
            ``-h``, or a ``--logit-bias`` entry not of the form
            ``TOKEN_ID(+/-)BIAS``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-s", "--seed", type=int, default=-1, help="RNG seed (use random seed for <= 0)",dest="seed")
    parser.add_argument("-t", "--threads", type=int, default=min(4, os.cpu_count() or 1), help="number of threads to use during computation",dest="n_threads")
    parser.add_argument("-n", "--n_predict", type=int, default=128, help="number of tokens to predict (-1 = infinity)",dest="n_predict")
    parser.add_argument("--n_parts", type=int, default=-1, help="number of model parts", dest="n_parts")
    parser.add_argument("-c", "--ctx_size", type=int, default=512, help="size of the prompt context",dest="n_ctx")
    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size for prompt processing",dest="n_batch")
    parser.add_argument("--keep", type=int, default=0, help="number of tokens to keep from the initial prompt",dest="n_keep")

    # Raw strings are collected here and converted into params.logit_bias below.
    parser.add_argument(
        "-l",
        "--logit-bias",
        type=str,
        action='append',
        help="--logit-bias TOKEN_ID(+/-)BIAS",
        dest="logit_bias_str"
    )
    parser.add_argument("--ignore-eos", action="store_true", help="ignore end of stream token and continue generating", dest="ignore_eos")
    parser.add_argument("--top_k", type=int, default=40, help="top-k sampling",dest="top_k")
    parser.add_argument("--top_p", type=float, default=0.95, help="top-p sampling",dest="top_p")
    parser.add_argument("--tfs", type=float, default=1.0, help="tail free sampling, parameter z (1.0 = disabled)",dest="tfs_z")
    parser.add_argument("--typical", type=float, default=1.0, help="locally typical sampling, parameter p (1.0 = disabled)",dest="typical_p")
    parser.add_argument("--temp", type=float, default=0.80, help="temperature",dest="temp")
    parser.add_argument("--repeat_penalty", type=float, default=1.10, help="penalize repeat sequence of tokens",dest="repeat_penalty")
    parser.add_argument("--repeat_last_n", type=int, default=64, help="last n tokens to consider for penalize ",dest="repeat_last_n")
    # dest must be frequency_penalty; writing to tfs_z would clobber tail-free sampling.
    parser.add_argument("--frequency_penalty", type=float, default=0.0, help="repeat alpha frequency penalty (0.0 = disabled)",dest="frequency_penalty")
    parser.add_argument("--presence_penalty", type=float, default=0.0, help="repeat alpha presence penalty (0.0 = disabled)",dest="presence_penalty")
    # Integer mode, off by default to match GptParams.mirostat.
    parser.add_argument("--mirostat", type=int, default=0, help="use Mirostat sampling (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)",dest="mirostat")
    parser.add_argument("--mirostat_ent", type=float, default=5.0, help="Mirostat target entropy, parameter tau represents the average surprise value",dest="mirostat_tau")
    parser.add_argument("--mirostat_lr", type=float, default=0.1, help="Mirostat learning rate, parameter eta",dest="mirostat_eta")

    parser.add_argument("-m", "--model", type=str, default="./models/llama-7B/ggml-model.bin", help="model path",dest="model")
    parser.add_argument("-p", "--prompt", type=str, default=None, help="initial prompt",dest="prompt")
    parser.add_argument("-f", "--file", type=str, default=None, help="file containing initial prompt to load",dest="file")
    # Empty string (not None) means "no session", matching GptParams.path_session.
    parser.add_argument("--session", type=str, default="", help="file to cache model state in (may be large!)",dest="path_session")
    parser.add_argument("--in-prefix", type=str, default="", help="string to prefix user inputs with", dest="input_prefix")
    parser.add_argument("--in-suffix", type=str, default="", help="append to input", dest="input_suffix")
    parser.add_argument(
        "-r",
        "--reverse-prompt",
        type=str,
        action='append',
        help="poll user input upon seeing PROMPT (can be\nspecified more than once for multiple prompts).",
        dest="antiprompt"
    )

    parser.add_argument("--lora", type=str, default="", help="apply LoRA adapter (implies --no-mmap)", dest="lora_adapter")
    parser.add_argument("--lora-base", type=str, default="", help="optional model to use as a base for the layers modified by the LoRA adapter", dest="lora_base")

    parser.add_argument("--memory_f32", action="store_false", help="use f32 instead of f16 for memory key+value",dest="memory_f16")
    parser.add_argument("--random-prompt", action="store_true", help="start with a randomized prompt.", dest="random_prompt")
    parser.add_argument(
        "--color",
        action="store_true",
        help="colorise output to distinguish prompt and user input from generations",
        dest="use_color"
    )
    parser.add_argument(
        "-i", "--interactive", action="store_true", help="run in interactive mode", dest="interactive"
    )

    parser.add_argument("--embedding", action="store_true", help="", dest="embedding")
    parser.add_argument(
        "--interactive-first",
        action="store_true",
        help="run in interactive mode and wait for input right away",
        dest="interactive_start"
    )

    parser.add_argument(
        "-ins",
        "--instruct",
        action="store_true",
        help="run in instruction mode (use with Alpaca or Vicuna models)",
        dest="instruct"
    )
    parser.add_argument("--no-penalize-nl", action="store_false", help="do not penalize newline token", dest="penalize_nl")
    parser.add_argument("--perplexity", action="store_true", help="compute perplexity over the prompt", dest="perplexity")
    parser.add_argument("--no-mmap", action="store_false",help="do not memory-map model (slower load but may reduce pageouts if not using mlock)",dest="use_mmap")
    parser.add_argument("--mlock", action="store_true",help="force system to keep model in RAM rather than swapping or compressing",dest="use_mlock")
    parser.add_argument("--mtest", action="store_true",help="compute maximum memory usage",dest="mem_test")
    parser.add_argument("--verbose-prompt", action="store_true",help="print prompt before generation",dest="verbose_prompt")

    #Custom args
    parser.add_argument("--fix-prefix", type=str, default="", help="append to input when generated n_predict tokens", dest="fix_prefix")
    parser.add_argument("--input-noecho", action="store_false", help="dont output the input", dest="input_echo")

    # Alias of -i/--interactive (same dest), kept for backward compatibility.
    parser.add_argument(
        "--interactive-start",
        action="store_true",
        help="run in interactive mode",
        dest="interactive"
    )

    args = parser.parse_args(argv)

    # logit_bias_str is not a GptParams field; remove it before unpacking.
    logit_bias_str = args.logit_bias_str
    delattr(args, "logit_bias_str")
    # action='append' leaves the dest as None when -r is never given;
    # GptParams.antiprompt is a list, so normalize to an empty one.
    if args.antiprompt is None:
        args.antiprompt = []
    params = GptParams(**vars(args))

    # --lora implies --no-mmap (see its help text).
    if params.lora_adapter:
        params.use_mmap = False

    # Each entry is TOKEN_ID, a sign, then the bias magnitude, e.g. "15043+1"
    # or "15043-0.5". Fractional biases are kept; malformed entries are
    # reported instead of being silently dropped.
    for entry in logit_bias_str or []:
        match = re.fullmatch(r"\s*(\d+)\s*([-+])\s*(\S+)\s*", entry)
        try:
            bias = float(match.group(3)) if match else None
        except ValueError:
            bias = None
        if bias is None:
            parser.error(f"argument -l/--logit-bias: invalid value {entry!r} (expected TOKEN_ID(+/-)BIAS)")
        params.logit_bias[int(match.group(1))] = -bias if match.group(2) == "-" else bias

    return params

def gpt_random_prompt(rng):
    """Return a fixed opening phrase chosen by ``rng``.

    ``rng`` is reduced modulo the number of available phrases (ten), so any
    integer, including a negative one, selects a valid entry.
    """
    openers = (
        "So",
        "Once upon a time",
        "When",
        "The",
        "After",
        "If",
        "import",
        "He",
        "She",
        "They",
    )
    # Derive the modulus from the table so the two can never drift apart.
    choice = rng % len(openers)
    return openers[choice]

if __name__ == "__main__":
    # Manual check: parse sys.argv and print the resulting parameters.
    parsed_params = gpt_params_parse()
    print(parsed_params)