|
#pragma once |
|
|
|
#include "llama.h" |
|
|
|
#include <array> |
|
#include <vector> |
|
|
|
|
|
|
|
struct llama_ubatch { |
|
bool equal_seqs; |
|
|
|
|
|
uint32_t n_tokens; |
|
uint32_t n_seq_tokens; |
|
uint32_t n_seqs; |
|
|
|
llama_token * token; |
|
float * embd; |
|
llama_pos * pos; |
|
int32_t * n_seq_id; |
|
llama_seq_id ** seq_id; |
|
int8_t * output; |
|
}; |
|
|
|
struct llama_sbatch_seq { |
|
int32_t n_seq_id; |
|
|
|
llama_seq_id * seq_id; |
|
|
|
size_t offset; |
|
size_t length; |
|
}; |
|
|
|
|
|
struct llama_sbatch { |
|
|
|
size_t n_tokens; |
|
|
|
size_t n_embd; |
|
|
|
bool logits_all; |
|
|
|
|
|
std::vector<size_t> ids; |
|
|
|
std::vector<size_t> out_ids; |
|
std::vector<llama_sbatch_seq> seq; |
|
|
|
const llama_batch * batch = nullptr; |
|
|
|
|
|
std::vector<llama_token> ubatch_token; |
|
std::vector<float> ubatch_embd; |
|
std::vector<llama_pos> ubatch_pos; |
|
std::vector<int32_t> ubatch_n_seq_id; |
|
std::vector<llama_seq_id *> ubatch_seq_id; |
|
std::vector<int8_t> ubatch_output; |
|
|
|
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false); |
|
|
|
void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length); |
|
|
|
|
|
llama_ubatch split_simple(size_t n_ubatch); |
|
|
|
|
|
llama_ubatch split_equal(size_t n_ubatch); |
|
|
|
|
|
llama_ubatch split_seq(size_t n_ubatch); |
|
|
|
void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); |
|
}; |
|
|
|
|
|
struct llama_batch_allocr { |
|
struct llama_batch batch; |
|
|
|
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; |
|
std::vector<llama_pos> pos; |
|
std::vector<int32_t> n_seq_id; |
|
std::vector<llama_seq_id *> seq_id; |
|
std::vector<int8_t> logits; |
|
|
|
|
|
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); |
|
}; |
|
|