| | #include "llama.h" |
| | #include <vector> |
| | #include <string> |
| | #include <cstring> |
| | #include <cstdio> |
| | #include <cstdlib> |
| |
|
| | |
| | static llama_model* g_model = nullptr; |
| | static llama_context* g_ctx = nullptr; |
| | static llama_sampler* g_smpl = nullptr; |
| |
|
| | extern "C" { |
| |
|
| | |
| | bool init_model(const char* model_path) { |
| | llama_backend_init(); |
| | |
| | llama_model_params model_params = llama_model_default_params(); |
| | model_params.n_gpu_layers = 0; |
| | g_model = llama_model_load_from_file(model_path, model_params); |
| | |
| | if (!g_model) return false; |
| |
|
| | llama_context_params ctx_params = llama_context_default_params(); |
| | ctx_params.n_ctx = 4096; |
| | ctx_params.n_batch = 512; |
| | ctx_params.n_threads = 8; |
| | ctx_params.n_threads_batch = 8; |
| | ctx_params.n_seq_max = 16; |
| | g_ctx = llama_init_from_model(g_model, ctx_params); |
| | |
| | if (!g_ctx) return false; |
| |
|
| | auto sparams = llama_sampler_chain_default_params(); |
| | g_smpl = llama_sampler_chain_init(sparams); |
| | llama_sampler_chain_add(g_smpl, llama_sampler_init_greedy()); |
| |
|
| | return true; |
| | } |
| |
|
| | |
| | void batch_add(llama_batch & batch, llama_token id, llama_pos pos, const std::vector<llama_seq_id> & seq_ids, bool logits) { |
| | batch.token[batch.n_tokens] = id; |
| | batch.pos[batch.n_tokens] = pos; |
| | batch.n_seq_id[batch.n_tokens] = seq_ids.size(); |
| | for (size_t i = 0; i < seq_ids.size(); ++i) { |
| | batch.seq_id[batch.n_tokens][i] = seq_ids[i]; |
| | } |
| | batch.logits[batch.n_tokens] = logits; |
| | batch.n_tokens++; |
| | } |
| |
|
| | |
| | static int g_count = 0; |
| | static int g_step = 0; |
| | static int g_max_tokens = 0; |
| | static std::vector<std::string> g_responses; |
| | static std::vector<bool> g_active; |
| | static std::vector<int> g_n_pos; |
| | static std::vector<int> g_logits_idx; |
| | static std::vector<std::vector<llama_token>> g_all_tokens; |
| | static llama_batch g_batch; |
| | static const llama_vocab* g_vocab = nullptr; |
| |
|
| | |
| | void start_batch(const char** prompts, int count, int max_tokens) { |
| | if (!g_ctx || count == 0) return; |
| |
|
| | g_vocab = llama_model_get_vocab(g_model); |
| | g_count = count; |
| | g_max_tokens = max_tokens; |
| | g_step = 0; |
| | |
| | |
| | g_responses.assign(count, ""); |
| | g_active.assign(count, true); |
| | g_n_pos.assign(count, 0); |
| | g_logits_idx.assign(count, -1); |
| | g_all_tokens.clear(); |
| |
|
| | |
| | for (int i = 0; i < count; i++) { |
| | int n_prompt = -llama_tokenize(g_vocab, prompts[i], strlen(prompts[i]), NULL, 0, true, true); |
| | std::vector<llama_token> tokens(n_prompt); |
| | llama_tokenize(g_vocab, prompts[i], strlen(prompts[i]), tokens.data(), tokens.size(), true, true); |
| | g_all_tokens.push_back(tokens); |
| | } |
| |
|
| | |
| | llama_memory_clear(llama_get_memory(g_ctx), true); |
| |
|
| | |
| | if (g_batch.token) llama_batch_free(g_batch); |
| | g_batch = llama_batch_init(4096, 0, 1); |
| |
|
| | |
| | g_batch.n_tokens = 0; |
| | for (int i = 0; i < count; i++) { |
| | for (size_t j = 0; j < g_all_tokens[i].size(); j++) { |
| | bool is_last = (j == g_all_tokens[i].size() - 1); |
| | if (is_last) g_logits_idx[i] = g_batch.n_tokens; |
| | batch_add(g_batch, g_all_tokens[i][j], g_n_pos[i]++, { (llama_seq_id)i }, is_last); |
| | } |
| | } |
| |
|
| | |
| | if (llama_decode(g_ctx, g_batch)) { |
| | fprintf(stderr, "Failed to decode prefill\n"); |
| | } |
| | } |
| |
|
| | |
| | |
| | bool decode_step(const char** results) { |
| | if (g_step >= g_max_tokens) return false; |
| |
|
| | g_batch.n_tokens = 0; |
| | bool any_active = false; |
| | std::vector<int> next_logits_idx(g_count, -1); |
| | int current_batch_pos = 0; |
| |
|
| | for (int i = 0; i < g_count; i++) { |
| | results[i] = nullptr; |
| | |
| | if (!g_active[i]) continue; |
| |
|
| | |
| | llama_token id = llama_sampler_sample(g_smpl, g_ctx, g_logits_idx[i]); |
| | llama_sampler_accept(g_smpl, id); |
| |
|
| | |
| | if (llama_vocab_is_eog(g_vocab, id) || g_n_pos[i] >= 4096) { |
| | g_active[i] = false; |
| | continue; |
| | } |
| |
|
| | |
| | static char buf[256]; |
| | int n = llama_token_to_piece(g_vocab, id, buf, sizeof(buf), 0, true); |
| | if (n < 0) { |
| | |
| | } else { |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | } |
| | |
| | |
| | |
| | |
| | |
| | results[i] = strdup(buf); |
| |
|
| | next_logits_idx[i] = current_batch_pos++; |
| | batch_add(g_batch, id, g_n_pos[i]++, { (llama_seq_id)i }, true); |
| | any_active = true; |
| | } |
| |
|
| | if (!any_active) return false; |
| |
|
| | g_logits_idx = next_logits_idx; |
| | if (llama_decode(g_ctx, g_batch)) { |
| | return false; |
| | } |
| |
|
| | g_step++; |
| | return true; |
| | } |
| |
|
| | |
| | void cleanup() { |
| | if (g_smpl) llama_sampler_free(g_smpl); |
| | if (g_ctx) llama_free(g_ctx); |
| | if (g_model) llama_model_free(g_model); |
| | } |
| |
|
| | } |
| |
|