Upload 4 files
Browse files- .gitattributes +1 -0
- ggml-common.h +0 -0
- ggml-metal.metal +0 -0
- main +3 -0
- main.cpp +954 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
main filter=lfs diff=lfs merge=lfs -text
|
ggml-common.h
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ggml-metal.metal
ADDED
The diff for this file is too large to render.
See raw diff
|
|
main
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:349cf33b64ee47606d53bfca53b64b0779f4c6b92c81aeccf04d793265e265fd
|
3 |
+
size 1676065
|
main.cpp
ADDED
@@ -0,0 +1,954 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "common.h"
|
2 |
+
|
3 |
+
#include "console.h"
|
4 |
+
#include "llama.h"
|
5 |
+
|
6 |
+
#include <cassert>
|
7 |
+
#include <cinttypes>
|
8 |
+
#include <cmath>
|
9 |
+
#include <cstdio>
|
10 |
+
#include <cstring>
|
11 |
+
#include <ctime>
|
12 |
+
#include <fstream>
|
13 |
+
#include <iostream>
|
14 |
+
#include <sstream>
|
15 |
+
#include <string>
|
16 |
+
#include <vector>
|
17 |
+
|
18 |
+
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
19 |
+
#include <signal.h>
|
20 |
+
#include <unistd.h>
|
21 |
+
#elif defined (_WIN32)
|
22 |
+
#define WIN32_LEAN_AND_MEAN
|
23 |
+
#ifndef NOMINMAX
|
24 |
+
#define NOMINMAX
|
25 |
+
#endif
|
26 |
+
#include <windows.h>
|
27 |
+
#include <signal.h>
|
28 |
+
#endif
|
29 |
+
|
30 |
+
#if defined(_MSC_VER)
|
31 |
+
#pragma warning(disable: 4244 4267) // possible loss of data
|
32 |
+
#endif
|
33 |
+
|
34 |
+
static llama_context ** g_ctx;
|
35 |
+
static llama_model ** g_model;
|
36 |
+
static gpt_params * g_params;
|
37 |
+
static std::vector<llama_token> * g_input_tokens;
|
38 |
+
static std::ostringstream * g_output_ss;
|
39 |
+
static std::vector<llama_token> * g_output_tokens;
|
40 |
+
static bool is_interacting = false;
|
41 |
+
|
42 |
+
static bool file_exists(const std::string &path) {
|
43 |
+
std::ifstream f(path.c_str());
|
44 |
+
return f.good();
|
45 |
+
}
|
46 |
+
|
47 |
+
static bool file_is_empty(const std::string &path) {
|
48 |
+
std::ifstream f;
|
49 |
+
f.exceptions(std::ifstream::failbit | std::ifstream::badbit);
|
50 |
+
f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
|
51 |
+
return f.tellg() == 0;
|
52 |
+
}
|
53 |
+
|
54 |
+
static void write_logfile(
|
55 |
+
const llama_context * ctx, const gpt_params & params, const llama_model * model,
|
56 |
+
const std::vector<llama_token> & input_tokens, const std::string & output,
|
57 |
+
const std::vector<llama_token> & output_tokens
|
58 |
+
) {
|
59 |
+
if (params.logdir.empty()) {
|
60 |
+
return;
|
61 |
+
}
|
62 |
+
|
63 |
+
const std::string timestamp = get_sortable_timestamp();
|
64 |
+
|
65 |
+
const bool success = create_directory_with_parents(params.logdir);
|
66 |
+
if (!success) {
|
67 |
+
fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
|
68 |
+
__func__, params.logdir.c_str());
|
69 |
+
return;
|
70 |
+
}
|
71 |
+
|
72 |
+
const std::string logfile_path = params.logdir + timestamp + ".yml";
|
73 |
+
FILE * logfile = fopen(logfile_path.c_str(), "w");
|
74 |
+
|
75 |
+
if (logfile == NULL) {
|
76 |
+
fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
|
77 |
+
return;
|
78 |
+
}
|
79 |
+
|
80 |
+
fprintf(logfile, "binary: main\n");
|
81 |
+
char model_desc[128];
|
82 |
+
llama_model_desc(model, model_desc, sizeof(model_desc));
|
83 |
+
dump_non_result_info_yaml(logfile, params, ctx, timestamp, input_tokens, model_desc);
|
84 |
+
|
85 |
+
fprintf(logfile, "\n");
|
86 |
+
fprintf(logfile, "######################\n");
|
87 |
+
fprintf(logfile, "# Generation Results #\n");
|
88 |
+
fprintf(logfile, "######################\n");
|
89 |
+
fprintf(logfile, "\n");
|
90 |
+
|
91 |
+
dump_string_yaml_multiline(logfile, "output", output.c_str());
|
92 |
+
dump_vector_int_yaml(logfile, "output_tokens", output_tokens);
|
93 |
+
|
94 |
+
llama_dump_timing_info_yaml(logfile, ctx);
|
95 |
+
fclose(logfile);
|
96 |
+
}
|
97 |
+
|
98 |
+
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
99 |
+
static void sigint_handler(int signo) {
|
100 |
+
if (signo == SIGINT) {
|
101 |
+
if (!is_interacting && g_params->interactive) {
|
102 |
+
is_interacting = true;
|
103 |
+
} else {
|
104 |
+
console::cleanup();
|
105 |
+
printf("\n");
|
106 |
+
// llama_print_timings(*g_ctx);
|
107 |
+
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
|
108 |
+
_exit(130);
|
109 |
+
}
|
110 |
+
}
|
111 |
+
}
|
112 |
+
#endif
|
113 |
+
|
114 |
+
static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) {
|
115 |
+
(void) level;
|
116 |
+
(void) user_data;
|
117 |
+
// LOG_TEE("%s", text);
|
118 |
+
}
|
119 |
+
|
120 |
+
int main(int argc, char ** argv) {
|
121 |
+
gpt_params params;
|
122 |
+
g_params = ¶ms;
|
123 |
+
|
124 |
+
if (!gpt_params_parse(argc, argv, params)) {
|
125 |
+
return 1;
|
126 |
+
}
|
127 |
+
llama_sampling_params & sparams = params.sparams;
|
128 |
+
|
129 |
+
#ifndef LOG_DISABLE_LOGS
|
130 |
+
log_set_target(log_filename_generator("main", "log"));
|
131 |
+
// LOG_TEE("Log start\n");
|
132 |
+
log_dump_cmdline(argc, argv);
|
133 |
+
llama_log_set(llama_log_callback_logTee, nullptr);
|
134 |
+
#endif // LOG_DISABLE_LOGS
|
135 |
+
|
136 |
+
// TODO: Dump params ?
|
137 |
+
//LOG("Params perplexity: %s\n", LOG_TOSTR(params.perplexity));
|
138 |
+
|
139 |
+
// save choice to use color for later
|
140 |
+
// (note for later: this is a slightly awkward choice)
|
141 |
+
console::init(params.simple_io, params.use_color);
|
142 |
+
atexit([]() { console::cleanup(); });
|
143 |
+
|
144 |
+
if (params.logits_all) {
|
145 |
+
printf("\n************\n");
|
146 |
+
printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
|
147 |
+
printf("************\n\n");
|
148 |
+
|
149 |
+
return 0;
|
150 |
+
}
|
151 |
+
|
152 |
+
if (params.embedding) {
|
153 |
+
printf("\n************\n");
|
154 |
+
printf("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
|
155 |
+
printf("************\n\n");
|
156 |
+
|
157 |
+
return 0;
|
158 |
+
}
|
159 |
+
|
160 |
+
if (params.n_ctx != 0 && params.n_ctx < 8) {
|
161 |
+
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
|
162 |
+
params.n_ctx = 8;
|
163 |
+
}
|
164 |
+
|
165 |
+
if (params.rope_freq_base != 0.0) {
|
166 |
+
LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base);
|
167 |
+
}
|
168 |
+
|
169 |
+
if (params.rope_freq_scale != 0.0) {
|
170 |
+
LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
|
171 |
+
}
|
172 |
+
|
173 |
+
// LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
|
174 |
+
// LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
|
175 |
+
|
176 |
+
if (params.seed == LLAMA_DEFAULT_SEED) {
|
177 |
+
params.seed = time(NULL);
|
178 |
+
}
|
179 |
+
|
180 |
+
// LOG_TEE("%s: seed = %u\n", __func__, params.seed);
|
181 |
+
|
182 |
+
std::mt19937 rng(params.seed);
|
183 |
+
if (params.random_prompt) {
|
184 |
+
params.prompt = gpt_random_prompt(rng);
|
185 |
+
}
|
186 |
+
|
187 |
+
LOG("%s: llama backend init\n", __func__);
|
188 |
+
llama_backend_init();
|
189 |
+
llama_numa_init(params.numa);
|
190 |
+
|
191 |
+
llama_model * model;
|
192 |
+
llama_context * ctx;
|
193 |
+
llama_context * ctx_guidance = NULL;
|
194 |
+
g_model = &model;
|
195 |
+
g_ctx = &ctx;
|
196 |
+
|
197 |
+
// load the model and apply lora adapter, if any
|
198 |
+
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
|
199 |
+
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
200 |
+
if (sparams.cfg_scale > 1.f) {
|
201 |
+
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
|
202 |
+
ctx_guidance = llama_new_context_with_model(model, lparams);
|
203 |
+
}
|
204 |
+
|
205 |
+
if (model == NULL) {
|
206 |
+
LOG_TEE("%s: error: unable to load model\n", __func__);
|
207 |
+
return 1;
|
208 |
+
}
|
209 |
+
|
210 |
+
const int n_ctx_train = llama_n_ctx_train(model);
|
211 |
+
const int n_ctx = llama_n_ctx(ctx);
|
212 |
+
// LOG("n_ctx: %d\n", n_ctx);
|
213 |
+
|
214 |
+
if (n_ctx > n_ctx_train) {
|
215 |
+
LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
|
216 |
+
__func__, n_ctx_train, n_ctx);
|
217 |
+
}
|
218 |
+
|
219 |
+
// print system information
|
220 |
+
// {
|
221 |
+
// LOG_TEE("\n");
|
222 |
+
// LOG_TEE("%s\n", get_system_info(params).c_str());
|
223 |
+
// }
|
224 |
+
|
225 |
+
std::string path_session = params.path_prompt_cache;
|
226 |
+
std::vector<llama_token> session_tokens;
|
227 |
+
|
228 |
+
if (!path_session.empty()) {
|
229 |
+
// LOG_TEE("%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str());
|
230 |
+
if (!file_exists(path_session)) {
|
231 |
+
// LOG_TEE("%s: session file does not exist, will create.\n", __func__);
|
232 |
+
} else if (file_is_empty(path_session)) {
|
233 |
+
// LOG_TEE("%s: The session file is empty. A new session will be initialized.\n", __func__);
|
234 |
+
} else {
|
235 |
+
// The file exists and is not empty
|
236 |
+
session_tokens.resize(n_ctx);
|
237 |
+
size_t n_token_count_out = 0;
|
238 |
+
if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
|
239 |
+
LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
|
240 |
+
return 1;
|
241 |
+
}
|
242 |
+
session_tokens.resize(n_token_count_out);
|
243 |
+
llama_set_rng_seed(ctx, params.seed);
|
244 |
+
// LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
|
245 |
+
}
|
246 |
+
}
|
247 |
+
|
248 |
+
const bool add_bos = llama_should_add_bos_token(model);
|
249 |
+
// LOG("add_bos: %d\n", add_bos);
|
250 |
+
|
251 |
+
std::vector<llama_token> embd_inp;
|
252 |
+
|
253 |
+
if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
|
254 |
+
LOG("tokenize the prompt\n");
|
255 |
+
if (params.chatml) {
|
256 |
+
params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
|
257 |
+
}
|
258 |
+
embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
|
259 |
+
} else {
|
260 |
+
LOG("use session tokens\n");
|
261 |
+
embd_inp = session_tokens;
|
262 |
+
}
|
263 |
+
|
264 |
+
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
|
265 |
+
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
266 |
+
|
267 |
+
// Should not run without any tokens
|
268 |
+
if (embd_inp.empty()) {
|
269 |
+
embd_inp.push_back(llama_token_bos(model));
|
270 |
+
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
|
271 |
+
}
|
272 |
+
|
273 |
+
// Tokenize negative prompt
|
274 |
+
std::vector<llama_token> guidance_inp;
|
275 |
+
int guidance_offset = 0;
|
276 |
+
int original_prompt_len = 0;
|
277 |
+
if (ctx_guidance) {
|
278 |
+
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
|
279 |
+
|
280 |
+
guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true);
|
281 |
+
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
|
282 |
+
|
283 |
+
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
|
284 |
+
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
|
285 |
+
|
286 |
+
original_prompt_len = original_inp.size();
|
287 |
+
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
|
288 |
+
LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
|
289 |
+
LOG("guidance_offset: %s", log_tostr(guidance_offset));
|
290 |
+
}
|
291 |
+
|
292 |
+
if ((int) embd_inp.size() > n_ctx - 4) {
|
293 |
+
LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
|
294 |
+
return 1;
|
295 |
+
}
|
296 |
+
|
297 |
+
// debug message about similarity of saved session, if applicable
|
298 |
+
size_t n_matching_session_tokens = 0;
|
299 |
+
if (!session_tokens.empty()) {
|
300 |
+
for (llama_token id : session_tokens) {
|
301 |
+
if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
|
302 |
+
break;
|
303 |
+
}
|
304 |
+
n_matching_session_tokens++;
|
305 |
+
}
|
306 |
+
// if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
|
307 |
+
// LOG_TEE("%s: using full prompt from session file\n", __func__);
|
308 |
+
// } else if (n_matching_session_tokens >= embd_inp.size()) {
|
309 |
+
// LOG_TEE("%s: session file has exact match for prompt!\n", __func__);
|
310 |
+
// } else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
|
311 |
+
// LOG_TEE("%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
|
312 |
+
// __func__, n_matching_session_tokens, embd_inp.size());
|
313 |
+
// } else {
|
314 |
+
// LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n",
|
315 |
+
// __func__, n_matching_session_tokens, embd_inp.size());
|
316 |
+
// }
|
317 |
+
|
318 |
+
// remove any "future" tokens that we might have inherited from the previous session
|
319 |
+
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
|
320 |
+
}
|
321 |
+
|
322 |
+
LOGLN(
|
323 |
+
"recalculate the cached logits (check): embd_inp.empty() %s, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu, embd_inp.size() %zu",
|
324 |
+
log_tostr(embd_inp.empty()), n_matching_session_tokens, embd_inp.size(), session_tokens.size(), embd_inp.size());
|
325 |
+
|
326 |
+
// if we will use the cache for the full prompt without reaching the end of the cache, force
|
327 |
+
// reevaluation of the last token token to recalculate the cached logits
|
328 |
+
if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) {
|
329 |
+
LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1);
|
330 |
+
|
331 |
+
session_tokens.resize(embd_inp.size() - 1);
|
332 |
+
}
|
333 |
+
|
334 |
+
// number of tokens to keep when resetting context
|
335 |
+
if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
|
336 |
+
params.n_keep = (int)embd_inp.size();
|
337 |
+
} else {
|
338 |
+
params.n_keep += add_bos; // always keep the BOS token
|
339 |
+
}
|
340 |
+
|
341 |
+
// prefix & suffix for instruct mode
|
342 |
+
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
|
343 |
+
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);
|
344 |
+
|
345 |
+
// LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
|
346 |
+
// LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
|
347 |
+
|
348 |
+
// chatml prefix & suffix
|
349 |
+
const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", add_bos, true);
|
350 |
+
const auto cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true);
|
351 |
+
|
352 |
+
// LOG("cml_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_pfx).c_str());
|
353 |
+
// LOG("cml_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_sfx).c_str());
|
354 |
+
|
355 |
+
// in instruct mode, we inject a prefix and a suffix to each input by the user
|
356 |
+
if (params.instruct) {
|
357 |
+
params.interactive_first = true;
|
358 |
+
params.antiprompt.emplace_back("### Instruction:\n\n");
|
359 |
+
}
|
360 |
+
// similar for chatml mode
|
361 |
+
else if (params.chatml) {
|
362 |
+
params.interactive_first = true;
|
363 |
+
params.antiprompt.emplace_back("<|im_start|>user\n");
|
364 |
+
}
|
365 |
+
|
366 |
+
// enable interactive mode if interactive start is specified
|
367 |
+
if (params.interactive_first) {
|
368 |
+
params.interactive = true;
|
369 |
+
}
|
370 |
+
|
371 |
+
// if (params.verbose_prompt) {
|
372 |
+
// LOG_TEE("\n");
|
373 |
+
// LOG_TEE("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
|
374 |
+
// LOG_TEE("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
|
375 |
+
// for (int i = 0; i < (int) embd_inp.size(); i++) {
|
376 |
+
// LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
|
377 |
+
// }
|
378 |
+
|
379 |
+
// if (ctx_guidance) {
|
380 |
+
// LOG_TEE("\n");
|
381 |
+
// LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
|
382 |
+
// LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
|
383 |
+
// for (int i = 0; i < (int) guidance_inp.size(); i++) {
|
384 |
+
// LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
|
385 |
+
// }
|
386 |
+
// }
|
387 |
+
|
388 |
+
// if (params.n_keep > add_bos) {
|
389 |
+
// LOG_TEE("%s: static prompt based on n_keep: '", __func__);
|
390 |
+
// for (int i = 0; i < params.n_keep; i++) {
|
391 |
+
// LOG_TEE("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
|
392 |
+
// }
|
393 |
+
// LOG_TEE("'\n");
|
394 |
+
// }
|
395 |
+
// LOG_TEE("\n");
|
396 |
+
// }
|
397 |
+
|
398 |
+
// ctrl+C handling
|
399 |
+
{
|
400 |
+
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
401 |
+
struct sigaction sigint_action;
|
402 |
+
sigint_action.sa_handler = sigint_handler;
|
403 |
+
sigemptyset (&sigint_action.sa_mask);
|
404 |
+
sigint_action.sa_flags = 0;
|
405 |
+
sigaction(SIGINT, &sigint_action, NULL);
|
406 |
+
#elif defined (_WIN32)
|
407 |
+
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
408 |
+
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
|
409 |
+
};
|
410 |
+
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
411 |
+
#endif
|
412 |
+
}
|
413 |
+
|
414 |
+
if (params.interactive) {
|
415 |
+
// LOG_TEE("%s: interactive mode on.\n", __func__);
|
416 |
+
|
417 |
+
if (!params.antiprompt.empty()) {
|
418 |
+
for (const auto & antiprompt : params.antiprompt) {
|
419 |
+
// LOG_TEE("Reverse prompt: '%s'\n", antiprompt.c_str());
|
420 |
+
if (params.verbose_prompt) {
|
421 |
+
auto tmp = ::llama_tokenize(ctx, antiprompt, false, true);
|
422 |
+
for (int i = 0; i < (int) tmp.size(); i++) {
|
423 |
+
// LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
|
424 |
+
}
|
425 |
+
}
|
426 |
+
}
|
427 |
+
}
|
428 |
+
|
429 |
+
// if (params.input_prefix_bos) {
|
430 |
+
// LOG_TEE("Input prefix with BOS\n");
|
431 |
+
// }
|
432 |
+
|
433 |
+
// if (!params.input_prefix.empty()) {
|
434 |
+
// LOG_TEE("Input prefix: '%s'\n", params.input_prefix.c_str());
|
435 |
+
// if (params.verbose_prompt) {
|
436 |
+
// auto tmp = ::llama_tokenize(ctx, params.input_prefix, true, true);
|
437 |
+
// for (int i = 0; i < (int) tmp.size(); i++) {
|
438 |
+
// LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
|
439 |
+
// }
|
440 |
+
// }
|
441 |
+
// }
|
442 |
+
|
443 |
+
// if (!params.input_suffix.empty()) {
|
444 |
+
// LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
|
445 |
+
// if (params.verbose_prompt) {
|
446 |
+
// auto tmp = ::llama_tokenize(ctx, params.input_suffix, false, true);
|
447 |
+
// for (int i = 0; i < (int) tmp.size(); i++) {
|
448 |
+
// LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
|
449 |
+
// }
|
450 |
+
// }
|
451 |
+
// }
|
452 |
+
}
|
453 |
+
// LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
|
454 |
+
// LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
|
455 |
+
// LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
456 |
+
|
457 |
+
// group-attention state
|
458 |
+
// number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
|
459 |
+
int ga_i = 0;
|
460 |
+
|
461 |
+
const int ga_n = params.grp_attn_n;
|
462 |
+
const int ga_w = params.grp_attn_w;
|
463 |
+
|
464 |
+
if (ga_n != 1) {
|
465 |
+
GGML_ASSERT(ga_n > 0 && "grp_attn_n must be positive"); // NOLINT
|
466 |
+
GGML_ASSERT(ga_w % ga_n == 0 && "grp_attn_w must be a multiple of grp_attn_n"); // NOLINT
|
467 |
+
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of grp_attn_w"); // NOLINT
|
468 |
+
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
|
469 |
+
// LOG_TEE("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w);
|
470 |
+
}
|
471 |
+
// LOG_TEE("\n\n");
|
472 |
+
|
473 |
+
if (params.interactive) {
|
474 |
+
const char *control_message;
|
475 |
+
if (params.multiline_input) {
|
476 |
+
control_message = " - To return control to LLaMa, end your input with '\\'.\n"
|
477 |
+
" - To return control without starting a new line, end your input with '/'.\n";
|
478 |
+
} else {
|
479 |
+
control_message = " - Press Return to return control to LLaMa.\n"
|
480 |
+
" - To return control without starting a new line, end your input with '/'.\n"
|
481 |
+
" - If you want to submit another line, end your input with '\\'.\n";
|
482 |
+
}
|
483 |
+
// LOG_TEE("== Running in interactive mode. ==\n");
|
484 |
+
// #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
485 |
+
// // LOG_TEE( " - Press Ctrl+C to interject at any time.\n");
|
486 |
+
// #endif
|
487 |
+
// LOG_TEE( "%s\n", control_message);
|
488 |
+
|
489 |
+
is_interacting = params.interactive_first;
|
490 |
+
}
|
491 |
+
|
492 |
+
bool is_antiprompt = false;
|
493 |
+
bool input_echo = true;
|
494 |
+
bool display = true;
|
495 |
+
bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size();
|
496 |
+
|
497 |
+
int n_past = 0;
|
498 |
+
int n_remain = params.n_predict;
|
499 |
+
int n_consumed = 0;
|
500 |
+
int n_session_consumed = 0;
|
501 |
+
int n_past_guidance = 0;
|
502 |
+
|
503 |
+
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
|
504 |
+
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
|
505 |
+
std::ostringstream output_ss; g_output_ss = &output_ss;
|
506 |
+
|
507 |
+
// the first thing we will do is to output the prompt, so set color accordingly
|
508 |
+
console::set_display(console::prompt);
|
509 |
+
display = params.display_prompt;
|
510 |
+
|
511 |
+
std::vector<llama_token> embd;
|
512 |
+
std::vector<llama_token> embd_guidance;
|
513 |
+
|
514 |
+
// tokenized antiprompts
|
515 |
+
std::vector<std::vector<llama_token>> antiprompt_ids;
|
516 |
+
|
517 |
+
antiprompt_ids.reserve(params.antiprompt.size());
|
518 |
+
for (const std::string & antiprompt : params.antiprompt) {
|
519 |
+
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
|
520 |
+
}
|
521 |
+
|
522 |
+
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
523 |
+
|
524 |
+
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
525 |
+
// predict
|
526 |
+
if (!embd.empty()) {
|
527 |
+
// Note: (n_ctx - 4) here is to match the logic for commandline prompt handling via
|
528 |
+
// --prompt or --file which uses the same value.
|
529 |
+
int max_embd_size = n_ctx - 4;
|
530 |
+
|
531 |
+
// Ensure the input doesn't exceed the context size by truncating embd if necessary.
|
532 |
+
if ((int) embd.size() > max_embd_size) {
|
533 |
+
const int skipped_tokens = (int) embd.size() - max_embd_size;
|
534 |
+
embd.resize(max_embd_size);
|
535 |
+
|
536 |
+
console::set_display(console::error);
|
537 |
+
printf("<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
|
538 |
+
console::set_display(console::reset);
|
539 |
+
fflush(stdout);
|
540 |
+
}
|
541 |
+
|
542 |
+
if (ga_n == 1) {
|
543 |
+
// infinite text generation via context shifting
|
544 |
+
// if we run out of context:
|
545 |
+
// - take the n_keep first tokens from the original prompt (via n_past)
|
546 |
+
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
|
547 |
+
if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
|
548 |
+
if (params.n_predict == -2) {
|
549 |
+
// LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
|
550 |
+
break;
|
551 |
+
}
|
552 |
+
|
553 |
+
const int n_left = n_past - params.n_keep;
|
554 |
+
const int n_discard = n_left/2;
|
555 |
+
|
556 |
+
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
|
557 |
+
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
558 |
+
|
559 |
+
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
|
560 |
+
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
|
561 |
+
|
562 |
+
n_past -= n_discard;
|
563 |
+
|
564 |
+
if (ctx_guidance) {
|
565 |
+
n_past_guidance -= n_discard;
|
566 |
+
}
|
567 |
+
|
568 |
+
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
|
569 |
+
|
570 |
+
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
|
571 |
+
|
572 |
+
LOG("clear session path\n");
|
573 |
+
path_session.clear();
|
574 |
+
}
|
575 |
+
} else {
|
576 |
+
// context extension via Self-Extend
|
577 |
+
while (n_past >= ga_i + ga_w) {
|
578 |
+
const int ib = (ga_n*ga_i)/ga_w;
|
579 |
+
const int bd = (ga_w/ga_n)*(ga_n - 1);
|
580 |
+
const int dd = (ga_w/ga_n) - ib*bd - ga_w;
|
581 |
+
|
582 |
+
LOG("\n");
|
583 |
+
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd);
|
584 |
+
LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
|
585 |
+
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
|
586 |
+
|
587 |
+
llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
|
588 |
+
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
|
589 |
+
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
|
590 |
+
|
591 |
+
n_past -= bd;
|
592 |
+
|
593 |
+
ga_i += ga_w/ga_n;
|
594 |
+
|
595 |
+
LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i);
|
596 |
+
}
|
597 |
+
}
|
598 |
+
|
599 |
+
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
|
600 |
+
if (n_session_consumed < (int) session_tokens.size()) {
|
601 |
+
size_t i = 0;
|
602 |
+
for ( ; i < embd.size(); i++) {
|
603 |
+
if (embd[i] != session_tokens[n_session_consumed]) {
|
604 |
+
session_tokens.resize(n_session_consumed);
|
605 |
+
break;
|
606 |
+
}
|
607 |
+
|
608 |
+
n_past++;
|
609 |
+
n_session_consumed++;
|
610 |
+
|
611 |
+
if (n_session_consumed >= (int) session_tokens.size()) {
|
612 |
+
++i;
|
613 |
+
break;
|
614 |
+
}
|
615 |
+
}
|
616 |
+
if (i > 0) {
|
617 |
+
embd.erase(embd.begin(), embd.begin() + i);
|
618 |
+
}
|
619 |
+
}
|
620 |
+
|
621 |
+
// evaluate tokens in batches
|
622 |
+
// embd is typically prepared beforehand to fit within a batch, but not always
|
623 |
+
if (ctx_guidance) {
|
624 |
+
int input_size = 0;
|
625 |
+
llama_token * input_buf = NULL;
|
626 |
+
|
627 |
+
if (n_past_guidance < (int) guidance_inp.size()) {
|
628 |
+
// Guidance context should have the same data with these modifications:
|
629 |
+
//
|
630 |
+
// * Replace the initial prompt
|
631 |
+
// * Shift everything by guidance_offset
|
632 |
+
embd_guidance = guidance_inp;
|
633 |
+
if (embd.begin() + original_prompt_len < embd.end()) {
|
634 |
+
embd_guidance.insert(
|
635 |
+
embd_guidance.end(),
|
636 |
+
embd.begin() + original_prompt_len,
|
637 |
+
embd.end()
|
638 |
+
);
|
639 |
+
}
|
640 |
+
|
641 |
+
input_buf = embd_guidance.data();
|
642 |
+
input_size = embd_guidance.size();
|
643 |
+
|
644 |
+
LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
|
645 |
+
} else {
|
646 |
+
input_buf = embd.data();
|
647 |
+
input_size = embd.size();
|
648 |
+
}
|
649 |
+
|
650 |
+
for (int i = 0; i < input_size; i += params.n_batch) {
|
651 |
+
int n_eval = std::min(input_size - i, params.n_batch);
|
652 |
+
if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
|
653 |
+
// LOG_TEE("%s : failed to eval\n", __func__);
|
654 |
+
return 1;
|
655 |
+
}
|
656 |
+
|
657 |
+
n_past_guidance += n_eval;
|
658 |
+
}
|
659 |
+
}
|
660 |
+
|
661 |
+
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
|
662 |
+
int n_eval = (int) embd.size() - i;
|
663 |
+
if (n_eval > params.n_batch) {
|
664 |
+
n_eval = params.n_batch;
|
665 |
+
}
|
666 |
+
|
667 |
+
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
|
668 |
+
|
669 |
+
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
|
670 |
+
LOG_TEE("%s : failed to eval\n", __func__);
|
671 |
+
return 1;
|
672 |
+
}
|
673 |
+
|
674 |
+
n_past += n_eval;
|
675 |
+
|
676 |
+
// LOG("n_past = %d\n", n_past);
|
677 |
+
// Display total tokens alongside total time
|
678 |
+
// if (params.n_print > 0 && n_past % params.n_print == 0) {
|
679 |
+
// LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
|
680 |
+
// }
|
681 |
+
}
|
682 |
+
|
683 |
+
if (!embd.empty() && !path_session.empty()) {
|
684 |
+
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
685 |
+
n_session_consumed = session_tokens.size();
|
686 |
+
}
|
687 |
+
}
|
688 |
+
|
689 |
+
embd.clear();
|
690 |
+
embd_guidance.clear();
|
691 |
+
|
692 |
+
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
693 |
+
// optionally save the session on first sample (for faster prompt loading next time)
|
694 |
+
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
|
695 |
+
need_to_save_session = false;
|
696 |
+
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
697 |
+
|
698 |
+
LOG("saved session to %s\n", path_session.c_str());
|
699 |
+
}
|
700 |
+
|
701 |
+
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
702 |
+
|
703 |
+
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
704 |
+
|
705 |
+
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
|
706 |
+
|
707 |
+
embd.push_back(id);
|
708 |
+
|
709 |
+
// echo this to console
|
710 |
+
input_echo = true;
|
711 |
+
|
712 |
+
// decrement remaining sampling budget
|
713 |
+
--n_remain;
|
714 |
+
|
715 |
+
LOG("n_remain: %d\n", n_remain);
|
716 |
+
} else {
|
717 |
+
// some user input remains from prompt or interaction, forward it to processing
|
718 |
+
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
|
719 |
+
while ((int) embd_inp.size() > n_consumed) {
|
720 |
+
embd.push_back(embd_inp[n_consumed]);
|
721 |
+
|
722 |
+
// push the prompt in the sampling context in order to apply repetition penalties later
|
723 |
+
// for the prompt, we don't apply grammar rules
|
724 |
+
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
|
725 |
+
|
726 |
+
++n_consumed;
|
727 |
+
if ((int) embd.size() >= params.n_batch) {
|
728 |
+
break;
|
729 |
+
}
|
730 |
+
}
|
731 |
+
}
|
732 |
+
|
733 |
+
// display text
|
734 |
+
if (input_echo && display) {
|
735 |
+
for (auto id : embd) {
|
736 |
+
const std::string token_str = llama_token_to_piece(ctx, id);
|
737 |
+
printf("%s", token_str.c_str());
|
738 |
+
|
739 |
+
if (embd.size() > 1) {
|
740 |
+
input_tokens.push_back(id);
|
741 |
+
} else {
|
742 |
+
output_tokens.push_back(id);
|
743 |
+
output_ss << token_str;
|
744 |
+
}
|
745 |
+
}
|
746 |
+
fflush(stdout);
|
747 |
+
}
|
748 |
+
// reset color to default if there is no pending user input
|
749 |
+
if (input_echo && (int) embd_inp.size() == n_consumed) {
|
750 |
+
console::set_display(console::reset);
|
751 |
+
display = true;
|
752 |
+
}
|
753 |
+
|
754 |
+
// if not currently processing queued inputs;
|
755 |
+
if ((int) embd_inp.size() <= n_consumed) {
|
756 |
+
// check for reverse prompt in the last n_prev tokens
|
757 |
+
if (!params.antiprompt.empty()) {
|
758 |
+
const int n_prev = 32;
|
759 |
+
const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
|
760 |
+
|
761 |
+
is_antiprompt = false;
|
762 |
+
// Check if each of the reverse prompts appears at the end of the output.
|
763 |
+
// If we're not running interactively, the reverse prompt might be tokenized with some following characters
|
764 |
+
// so we'll compensate for that by widening the search window a bit.
|
765 |
+
for (std::string & antiprompt : params.antiprompt) {
|
766 |
+
size_t extra_padding = params.interactive ? 0 : 2;
|
767 |
+
size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
|
768 |
+
? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
|
769 |
+
: 0;
|
770 |
+
|
771 |
+
if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
|
772 |
+
if (params.interactive) {
|
773 |
+
is_interacting = true;
|
774 |
+
}
|
775 |
+
is_antiprompt = true;
|
776 |
+
break;
|
777 |
+
}
|
778 |
+
}
|
779 |
+
|
780 |
+
// check for reverse prompt using special tokens
|
781 |
+
llama_token last_token = llama_sampling_last(ctx_sampling);
|
782 |
+
for (std::vector<llama_token> ids : antiprompt_ids) {
|
783 |
+
if (ids.size() == 1 && last_token == ids[0]) {
|
784 |
+
if (params.interactive) {
|
785 |
+
is_interacting = true;
|
786 |
+
}
|
787 |
+
is_antiprompt = true;
|
788 |
+
break;
|
789 |
+
}
|
790 |
+
}
|
791 |
+
|
792 |
+
if (is_antiprompt) {
|
793 |
+
LOG("found antiprompt: %s\n", last_output.c_str());
|
794 |
+
}
|
795 |
+
}
|
796 |
+
|
797 |
+
// deal with end of text token in interactive mode
|
798 |
+
if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
|
799 |
+
LOG("found EOS token\n");
|
800 |
+
|
801 |
+
if (params.interactive) {
|
802 |
+
if (!params.antiprompt.empty()) {
|
803 |
+
// tokenize and inject first reverse prompt
|
804 |
+
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
|
805 |
+
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
|
806 |
+
is_antiprompt = true;
|
807 |
+
}
|
808 |
+
|
809 |
+
is_interacting = true;
|
810 |
+
printf("\n");
|
811 |
+
} else if (params.instruct || params.chatml) {
|
812 |
+
is_interacting = true;
|
813 |
+
}
|
814 |
+
}
|
815 |
+
|
816 |
+
if (n_past > 0 && is_interacting) {
|
817 |
+
LOG("waiting for user input\n");
|
818 |
+
|
819 |
+
if (params.instruct || params.chatml) {
|
820 |
+
printf("\n> ");
|
821 |
+
}
|
822 |
+
|
823 |
+
if (params.input_prefix_bos) {
|
824 |
+
LOG("adding input prefix BOS token\n");
|
825 |
+
embd_inp.push_back(llama_token_bos(model));
|
826 |
+
}
|
827 |
+
|
828 |
+
std::string buffer;
|
829 |
+
if (!params.input_prefix.empty()) {
|
830 |
+
LOG("appending input prefix: '%s'\n", params.input_prefix.c_str());
|
831 |
+
printf("%s", params.input_prefix.c_str());
|
832 |
+
}
|
833 |
+
|
834 |
+
// color user input only
|
835 |
+
console::set_display(console::user_input);
|
836 |
+
display = params.display_prompt;
|
837 |
+
|
838 |
+
std::string line;
|
839 |
+
bool another_line = true;
|
840 |
+
do {
|
841 |
+
another_line = console::readline(line, params.multiline_input);
|
842 |
+
buffer += line;
|
843 |
+
} while (another_line);
|
844 |
+
|
845 |
+
// done taking input, reset color
|
846 |
+
console::set_display(console::reset);
|
847 |
+
display = true;
|
848 |
+
|
849 |
+
// Add tokens to embd only if the input buffer is non-empty
|
850 |
+
// Entering a empty line lets the user pass control back
|
851 |
+
if (buffer.length() > 1) {
|
852 |
+
// append input suffix if any
|
853 |
+
if (!params.input_suffix.empty()) {
|
854 |
+
LOG("appending input suffix: '%s'\n", params.input_suffix.c_str());
|
855 |
+
printf("%s", params.input_suffix.c_str());
|
856 |
+
}
|
857 |
+
|
858 |
+
LOG("buffer: '%s'\n", buffer.c_str());
|
859 |
+
|
860 |
+
const size_t original_size = embd_inp.size();
|
861 |
+
|
862 |
+
// instruct mode: insert instruction prefix
|
863 |
+
if (params.instruct && !is_antiprompt) {
|
864 |
+
LOG("inserting instruction prefix\n");
|
865 |
+
n_consumed = embd_inp.size();
|
866 |
+
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
|
867 |
+
}
|
868 |
+
// chatml mode: insert user chat prefix
|
869 |
+
if (params.chatml && !is_antiprompt) {
|
870 |
+
LOG("inserting chatml prefix\n");
|
871 |
+
n_consumed = embd_inp.size();
|
872 |
+
embd_inp.insert(embd_inp.end(), cml_pfx.begin(), cml_pfx.end());
|
873 |
+
}
|
874 |
+
if (params.escape) {
|
875 |
+
process_escapes(buffer);
|
876 |
+
}
|
877 |
+
|
878 |
+
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
|
879 |
+
const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
|
880 |
+
const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
|
881 |
+
|
882 |
+
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
|
883 |
+
|
884 |
+
embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
|
885 |
+
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
|
886 |
+
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
|
887 |
+
|
888 |
+
// instruct mode: insert response suffix
|
889 |
+
if (params.instruct) {
|
890 |
+
LOG("inserting instruction suffix\n");
|
891 |
+
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
|
892 |
+
}
|
893 |
+
// chatml mode: insert assistant chat suffix
|
894 |
+
if (params.chatml) {
|
895 |
+
LOG("inserting chatml suffix\n");
|
896 |
+
embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
|
897 |
+
}
|
898 |
+
|
899 |
+
for (size_t i = original_size; i < embd_inp.size(); ++i) {
|
900 |
+
const llama_token token = embd_inp[i];
|
901 |
+
output_tokens.push_back(token);
|
902 |
+
output_ss << llama_token_to_piece(ctx, token);
|
903 |
+
}
|
904 |
+
|
905 |
+
n_remain -= line_inp.size();
|
906 |
+
LOG("n_remain: %d\n", n_remain);
|
907 |
+
} else {
|
908 |
+
LOG("empty line, passing control back\n");
|
909 |
+
}
|
910 |
+
|
911 |
+
input_echo = false; // do not echo this again
|
912 |
+
}
|
913 |
+
|
914 |
+
if (n_past > 0) {
|
915 |
+
if (is_interacting) {
|
916 |
+
llama_sampling_reset(ctx_sampling);
|
917 |
+
}
|
918 |
+
is_interacting = false;
|
919 |
+
}
|
920 |
+
}
|
921 |
+
|
922 |
+
// end of text token
|
923 |
+
if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) {
|
924 |
+
// LOG_TEE(" [end of text]\n");
|
925 |
+
break;
|
926 |
+
}
|
927 |
+
|
928 |
+
// In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
|
929 |
+
// We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size).
|
930 |
+
if (params.interactive && n_remain <= 0 && params.n_predict >= 0) {
|
931 |
+
n_remain = params.n_predict;
|
932 |
+
is_interacting = true;
|
933 |
+
}
|
934 |
+
}
|
935 |
+
|
936 |
+
if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
|
937 |
+
// LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
|
938 |
+
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
939 |
+
}
|
940 |
+
|
941 |
+
// llama_print_timings(ctx);
|
942 |
+
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
|
943 |
+
|
944 |
+
if (ctx_guidance) { llama_free(ctx_guidance); }
|
945 |
+
llama_free(ctx);
|
946 |
+
llama_free_model(model);
|
947 |
+
|
948 |
+
llama_sampling_free(ctx_sampling);
|
949 |
+
llama_backend_free();
|
950 |
+
|
951 |
+
|
952 |
+
|
953 |
+
return 0;
|
954 |
+
}
|