//===----------------------------------------------------------------------===//
//
// Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved.
//
// TPU-MLIR is licensed under the 2-Clause BSD License except for the
// third-party components.
//
//===----------------------------------------------------------------------===//

#include <iostream>
#include <cstdlib>
#include <vector>
#include <assert.h>
#include <chrono>
#include <algorithm>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include "memory.h"
#include "bmruntime_interface.h"
#include <getopt.h>
#include <stdio.h>
#include <inttypes.h>
#include <random>
#include <numeric>

// bit pattern used as the masked-out fill value in the attention mask
static const uint16_t ATTENTION_MASK = 0xF0E2;

class Llama3 {
public:
  void init(const std::vector<int> &devid, std::string model_path);
  void deinit();
  int forward_first(std::vector<int> &tokens);
  int forward_next();
  std::vector<int> generate(std::vector<int> &history_tokens, int EOS);
  std::mt19937 sgen;
  Llama3() : sgen(std::random_device()()){};

private:
  void net_launch(const bm_net_info_t *net, int stage_idx = 0);
  inline void d2d(bm_device_mem_t &dst, bm_device_mem_t &src);
  void head_launch(const bm_net_info_t *net, bm_device_mem_t &logits_mem);
  int greedy_search(const bm_net_info_t *net, bm_device_mem_t &logits_mem);
  int penalty_sample(const bm_net_info_t *net, bm_device_mem_t &logits_mem);

public:
  int token_length;
  int SEQLEN;     // read from bmodel
  int NUM_LAYERS; // read from bmodel
  bool io_alone;
  std::vector<int> visited_tokens;

  // generation
  float temperature;
  float top_p;
  float repeat_penalty;
  int repeat_last_n;
  int max_new_tokens;
  std::string generation_mode;
  std::string prompt_mode;

private:
  std::vector<bm_handle_t> handles;
  bm_handle_t bm_handle;
  void *p_bmrt;
  std::vector<const bm_net_info_t *> net_blocks;
  std::vector<const bm_net_info_t *> net_blocks_cache;
  const bm_net_info_t *net_embed;
  const bm_net_info_t *net_embed_cache;
  const bm_net_info_t *net_lm, *net_greedy_head, *net_penalty_sample_head;
  std::vector<bm_device_mem_t> past_key;
  std::vector<bm_device_mem_t> past_value;
};

void Llama3::net_launch(const bm_net_info_t *net, int stage_idx) {
  std::vector<bm_tensor_t> in_tensors(net->input_num);
  std::vector<bm_tensor_t> out_tensors(net->output_num);

  for (int i = 0; i < net->input_num; i++) {
    bmrt_tensor_with_device(
        &in_tensors[i], net->stages[stage_idx].input_mems[i],
        net->input_dtypes[i], net->stages[stage_idx].input_shapes[i]);
  }
  for (int i = 0; i < net->output_num; i++) {
    bmrt_tensor_with_device(
        &out_tensors[i], net->stages[stage_idx].output_mems[i],
        net->output_dtypes[i], net->stages[stage_idx].output_shapes[i]);
  }
  auto ret = bmrt_launch_tensor_ex(p_bmrt, net->name, in_tensors.data(),
                                   net->input_num, out_tensors.data(),
                                   net->output_num, true, false);
  assert(ret);
  bm_thread_sync(bm_handle);
}

void Llama3::d2d(bm_device_mem_t &dst, bm_device_mem_t &src) {
  bm_memcpy_d2d_byte(bm_handle, dst, 0, src, 0, bm_mem_get_device_size(src));
}

void Llama3::init(const std::vector<int> &devices, std::string model_path) {
  // request bm_handle
  std::cout << "Device [ ";
  for (auto d : devices) {
    std::cout << d << " ";
  }
  std::cout << "] loading ....\n";
  for (auto d : devices) {
    bm_handle_t h;
    bm_status_t status = bm_dev_request(&h, d);
    assert(BM_SUCCESS == status);
    handles.push_back(h);
  }
  bm_handle = handles[0];

  // create bmruntime
#ifdef SOC_TARGET
  p_bmrt = bmrt_create(handles[0]);
#else
  p_bmrt = bmrt_create_ex(handles.data(), handles.size());
#endif
  assert(NULL != p_bmrt);

  // load bmodel by file
  printf("Model[%s] loading ....\n", model_path.c_str());
  bool ret = bmrt_load_bmodel(p_bmrt, model_path.c_str());
  assert(true == ret);
  printf("Done!\n");

  // net embed and lm_head
  net_embed = bmrt_get_network_info(p_bmrt, "embedding");
  net_embed_cache = bmrt_get_network_info(p_bmrt, "embedding_cache");
  net_lm = bmrt_get_network_info(p_bmrt, "lm_head");
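  // The bmodel is expected to expose five fixed networks (embedding,
  // embedding_cache, lm_head, greedy_head, penalty_sample_head) plus two
  // networks per transformer layer (block_i and block_cache_i); the
  // NUM_LAYERS computation below relies on exactly this layout.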
  net_greedy_head = bmrt_get_network_info(p_bmrt, "greedy_head");
  net_penalty_sample_head =
      bmrt_get_network_info(p_bmrt, "penalty_sample_head");
  SEQLEN = net_embed->stages[0].input_shapes[0].dims[1]; // real seqlen
  auto num_nets = bmrt_get_network_number(p_bmrt);
  NUM_LAYERS = (num_nets - 5) / 2;

  // resize
  visited_tokens.resize(SEQLEN);

  // net blocks
  for (int i = 0; i < NUM_LAYERS; i++) {
    auto block_name = "block_" + std::to_string(i);
    auto cache_name = "block_cache_" + std::to_string(i);
    net_blocks.emplace_back(bmrt_get_network_info(p_bmrt, block_name.c_str()));
    net_blocks_cache.emplace_back(
        bmrt_get_network_info(p_bmrt, cache_name.c_str()));
  }

  // kv cache
  past_key.resize(NUM_LAYERS);
  past_value.resize(NUM_LAYERS);
  auto addr_mode = net_blocks_cache[0]->addr_mode;
  io_alone = addr_mode == 1;
  for (int i = 0; i < NUM_LAYERS; i++) {
    assert(addr_mode == net_blocks_cache[i]->addr_mode);
    if (io_alone) {
      past_key[i] = net_blocks_cache[i]->stages[0].input_mems[3];
      past_value[i] = net_blocks_cache[i]->stages[0].input_mems[4];
    } else {
      auto ret = bm_malloc_device_byte(bm_handle, &past_key[i],
                                       net_blocks_cache[i]->max_input_bytes[3]);
      assert(BM_SUCCESS == ret);
      ret = bm_malloc_device_byte(bm_handle, &past_value[i],
                                  net_blocks_cache[i]->max_input_bytes[4]);
      assert(BM_SUCCESS == ret);
    }
  }
}

void Llama3::deinit() {
  if (false == io_alone) {
    for (int i = 0; i < NUM_LAYERS; i++) {
      bm_free_device(bm_handle, past_key[i]);
      bm_free_device(bm_handle, past_value[i]);
    }
  }
  bmrt_destroy(p_bmrt);
  for (auto h : handles) {
    bm_dev_free(h);
  }
}

void Llama3::head_launch(const bm_net_info_t *net,
                         bm_device_mem_t &logits_mem) {
  std::vector<bm_tensor_t> in_tensors(net->input_num);
  std::vector<bm_tensor_t> out_tensors(net->output_num);

  bmrt_tensor_with_device(&in_tensors[0], logits_mem, net->input_dtypes[0],
                          net->stages[0].input_shapes[0]);

  for (int i = 1; i < net->input_num; i++) {
    bmrt_tensor_with_device(
        &in_tensors[i], net->stages[0].input_mems[i],
        net->input_dtypes[i], net->stages[0].input_shapes[i]);
  }
  for (int i = 0; i < net->output_num; i++) {
    bmrt_tensor_with_device(
        &out_tensors[i], net->stages[0].output_mems[i],
        net->output_dtypes[i], net->stages[0].output_shapes[i]);
  }
  auto ret = bmrt_launch_tensor_ex(p_bmrt, net->name, in_tensors.data(),
                                   net->input_num, out_tensors.data(),
                                   net->output_num, true, false);
  assert(ret);
  bm_thread_sync(bm_handle);
}

int Llama3::greedy_search(const bm_net_info_t *net,
                          bm_device_mem_t &logits_mem) {
  auto &out_mem = net->stages[0].output_mems[0];
  head_launch(net, logits_mem);
  int token = 0;
  bm_memcpy_d2s(bm_handle, (void *)&token, out_mem);
  return token;
}

int Llama3::penalty_sample(const bm_net_info_t *net,
                           bm_device_mem_t &logits_mem) {
  auto &in1_mem = net->stages[0].input_mems[1];
  auto &in2_mem = net->stages[0].input_mems[2];
  auto &in3_mem = net->stages[0].input_mems[3];
  auto &in4_mem = net->stages[0].input_mems[4];
  auto &out0_mem = net->stages[0].output_mems[0];
  auto &out1_mem = net->stages[0].output_mems[1];

  // repeat_penalty + top_p + top_k + temperature
  std::vector<int> generated_tokens(SEQLEN, visited_tokens[token_length - 1]);
  repeat_last_n = std::min(repeat_last_n, token_length);
  std::copy(visited_tokens.begin() + token_length - repeat_last_n,
            visited_tokens.begin() + token_length, generated_tokens.begin());
  bm_memcpy_s2d(bm_handle, in1_mem, (void *)generated_tokens.data());
  bm_memcpy_s2d(bm_handle, in2_mem, (void *)&top_p);
  bm_memcpy_s2d(bm_handle, in3_mem, (void *)&temperature);
  bm_memcpy_s2d(bm_handle, in4_mem, (void *)&repeat_penalty);

  // inference
  head_launch(net, logits_mem);

  // get logits & tokens
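  // The sampling head returns the surviving candidates' probabilities
  // (output 0) and their token ids (output 1); one id is then drawn from
  // that distribution below.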
  int candidate_num = net->stages[0].output_shapes[0].dims[1];
  std::vector<float> probs(candidate_num);
  bm_memcpy_d2s(bm_handle, probs.data(), out0_mem);
  std::vector<int> tokens(candidate_num);
  bm_memcpy_d2s(bm_handle, tokens.data(), out1_mem);

  // penalty_sample
  std::discrete_distribution<> dist(probs.begin(), probs.end());
  return tokens[dist(sgen)];
}

int Llama3::forward_first(std::vector<int> &tokens) {
  std::vector<int> position_id(SEQLEN, 0);
  std::vector<uint16_t> attention_mask(SEQLEN * SEQLEN, ATTENTION_MASK);
  std::copy(tokens.begin(), tokens.end(), visited_tokens.data());

  token_length = tokens.size();

  for (int i = 0; i < token_length; i++) {
    position_id[i] = i;
  }
  for (int i = 0; i < token_length; i++) {
    for (int j = 0; j < SEQLEN; j++) {
      if (j <= i) {
        attention_mask[i * SEQLEN + j] = 0;
      }
    }
  }

  // forward embedding
  auto &in_mem = net_embed->stages[0].input_mems[0];
  auto &out_mem = net_embed->stages[0].output_mems[0];
  bm_memcpy_s2d(bm_handle, in_mem, (void *)visited_tokens.data());
  net_launch(net_embed); // prefill embedding

  // forward blocks
  for (int idx = 0; idx < NUM_LAYERS; idx++) {
    auto &in0_mem = net_blocks[idx]->stages[0].input_mems[0];
    auto &in1_mem = net_blocks[idx]->stages[0].input_mems[1];
    auto &in2_mem = net_blocks[idx]->stages[0].input_mems[2];
    d2d(in0_mem, out_mem);
    if (idx == 0) {
      // only the first block needs the copy
      bm_memcpy_s2d(bm_handle, in1_mem, (void *)position_id.data());
      bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
    }
    net_launch(net_blocks[idx]);
    out_mem = net_blocks[idx]->stages[0].output_mems[0];
    d2d(past_key[idx], net_blocks[idx]->stages[0].output_mems[1]);
    d2d(past_value[idx], net_blocks[idx]->stages[0].output_mems[2]);
  }

  // forward lmhead
  int bytes = out_mem.size / SEQLEN; // per-token hidden-state size in bytes
  auto &lm_in_mem = net_lm->stages[0].input_mems[0];
  auto &lm_out_mem = net_lm->stages[0].output_mems[0];
  bm_memcpy_d2d_byte(bm_handle, lm_in_mem, 0, out_mem,
                     (token_length - 1) * bytes, bytes);
  net_launch(net_lm);

  int token = 0;
  if (generation_mode == "greedy") {
    token = greedy_search(net_greedy_head, lm_out_mem);
  } else if (generation_mode == "penalty_sample") {
    token = penalty_sample(net_penalty_sample_head, lm_out_mem);
  }

  visited_tokens[token_length] = token;
  token_length += 1;
  return token;
}

int Llama3::forward_next() {
  int cur_token = visited_tokens[token_length - 1];

  std::vector<uint16_t> attention_mask(SEQLEN + 1, 0);
  for (int i = token_length - 1; i < SEQLEN; i++) {
    attention_mask[i] = ATTENTION_MASK;
  }
  int32_t position_id = token_length - 1;

  // embedding
  auto &in_mem = net_embed_cache->stages[0].input_mems[0];
  auto &out_mem = net_embed_cache->stages[0].output_mems[0];
  bm_memcpy_s2d(bm_handle, in_mem, (void *)&cur_token);
  net_launch(net_embed_cache);

  // blocks
  int bytes =
      bm_mem_get_device_size(net_blocks_cache[0]->stages[0].output_mems[1]);
  int token_offset = (token_length - 1) * bytes;
  for (int idx = 0; idx < NUM_LAYERS; idx++) {
    auto &in0_mem = net_blocks_cache[idx]->stages[0].input_mems[0];
    auto &in1_mem = net_blocks_cache[idx]->stages[0].input_mems[1];
    auto &in2_mem = net_blocks_cache[idx]->stages[0].input_mems[2];
    auto &in3_mem = net_blocks_cache[idx]->stages[0].input_mems[3];
    auto &in4_mem = net_blocks_cache[idx]->stages[0].input_mems[4];
    auto &out0_mem = net_blocks_cache[idx]->stages[0].output_mems[0];
    auto &out1_mem = net_blocks_cache[idx]->stages[0].output_mems[1];
    auto &out2_mem = net_blocks_cache[idx]->stages[0].output_mems[2];
    d2d(in0_mem, out_mem);
    if (io_alone) {
      if (idx == 0) {
        bm_memcpy_s2d(bm_handle, in1_mem, (void *)&position_id);
        bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
      } else {
        d2d(in1_mem, net_blocks_cache[0]->stages[0].input_mems[1]);
        d2d(in2_mem, net_blocks_cache[0]->stages[0].input_mems[2]);
      }
    } else {
      if (idx == 0) {
        bm_memcpy_s2d(bm_handle, in1_mem, (void *)&position_id);
        bm_memcpy_s2d(bm_handle, in2_mem, (void *)attention_mask.data());
      }
      d2d(in3_mem, past_key[idx]);
      d2d(in4_mem, past_value[idx]);
    }
    net_launch(net_blocks_cache[idx]);
    out_mem = out0_mem;
    // append this step's key/value to the cache at the current token offset
    bm_memcpy_d2d_byte(bm_handle, past_key[idx], token_offset, out1_mem, 0,
                       bytes);
    bm_memcpy_d2d_byte(bm_handle, past_value[idx], token_offset, out2_mem, 0,
                       bytes);
  }

  // forward lmhead
  auto &lm_in_mem = net_lm->stages[0].input_mems[0];
  auto &lm_out_mem = net_lm->stages[0].output_mems[0];
  d2d(lm_in_mem, out_mem);
  net_launch(net_lm);

  int token = 0;
  if (generation_mode == "greedy") {
    token = greedy_search(net_greedy_head, lm_out_mem);
  } else if (generation_mode == "penalty_sample") {
    token = penalty_sample(net_penalty_sample_head, lm_out_mem);
  }

  visited_tokens[token_length] = token;
  token_length += 1;
  return token;
}

std::vector<int> Llama3::generate(std::vector<int> &history_tokens, int EOS) {
  if (history_tokens.empty()) {
    printf("Sorry: your question is empty!!\n");
    history_tokens.clear();
    return {};
  }

  // make sure the prompt is not too long
  if ((int)history_tokens.size() > SEQLEN - 10) {
    history_tokens.clear();
    printf("Error: your question is too long!\n");
    return {};
  }

  std::vector<int> result_tokens;
  int token = forward_first(history_tokens);
  while (token != EOS && token_length < SEQLEN) {
    result_tokens.emplace_back(token);
    token = forward_next();
  }

  return result_tokens;
}

PYBIND11_MODULE(chat, m) {
  pybind11::class_<Llama3>(m, "Llama3")
      .def(pybind11::init<>())
      .def("init", &Llama3::init)
      .def("forward_first", &Llama3::forward_first)
      .def("forward_next", &Llama3::forward_next)
      .def("generate", &Llama3::generate)
      .def("deinit", &Llama3::deinit)
      .def_readwrite("SEQLEN", &Llama3::SEQLEN) // read SEQLEN in pipeline.py
      .def_readwrite("token_length", &Llama3::token_length)
      .def_readwrite("temperature", &Llama3::temperature)
      .def_readwrite("top_p", &Llama3::top_p)
      .def_readwrite("repeat_penalty", &Llama3::repeat_penalty)
      .def_readwrite("repeat_last_n", &Llama3::repeat_last_n)
      .def_readwrite("max_new_tokens", &Llama3::max_new_tokens)
      .def_readwrite("generation_mode", &Llama3::generation_mode)
      .def_readwrite("prompt_mode", &Llama3::prompt_mode);
}
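// Illustrative sketch (not part of the original source): how this class could
// be driven directly from C++, mirroring what pipeline.py does through the
// pybind11 bindings above. The device id, bmodel path, token ids, and EOS id
// below are placeholders for whatever the surrounding application provides.
//
//   int main() {
//     Llama3 model;
//     model.init({0}, "llama3.bmodel");    // hypothetical single-device setup
//     model.generation_mode = "greedy";    // or "penalty_sample" (set
//                                          // temperature/top_p/... first)
//     std::vector<int> prompt = {1, 2, 3}; // placeholder token ids
//     auto answer = model.generate(prompt, /*EOS=*/128001); // placeholder EOS
//     model.deinit();
//     return 0;
//   }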