// Hugging Face Spaces page residue ("Spaces: Sleeping") kept as a comment so
// the translation unit stays compilable.
| namespace asio = boost::asio; | |
| namespace beast = boost::beast; | |
| namespace http = beast::http; | |
| using json = nlohmann::json; | |
// Snapshot of one running llama-server child process.
struct WorkerInfo {
  std::string model;        // model identifier handed to llama-server via -hf
  int port = 0;             // TCP port the worker listens on
  pid_t pid = -1;           // child process id; -1 means "no process"
  std::string last_loaded;  // UTC ISO-8601 timestamp of when the worker became ready
};
// Current UTC wall-clock time formatted as ISO-8601 ("YYYY-MM-DDTHH:MM:SSZ").
static std::string now_utc_iso() {
  const std::time_t now = std::time(nullptr);
  std::tm parts{};
  gmtime_r(&now, &parts);  // thread-safe UTC breakdown
  char buf[32];
  std::strftime(buf, sizeof(buf), "%Y-%m-%dT%H:%M:%SZ", &parts);
  return buf;
}
// Returns the value of environment variable `name`, or `fallback` when the
// variable is unset or empty.
static std::string get_env_or(const char *name, const std::string &fallback) {
  if (const char *value = std::getenv(name); value && *value) {
    return value;
  }
  return fallback;
}
// Returns environment variable `name` parsed as a base-10 int, or `fallback`
// when the variable is unset, empty, out of range, or not a valid integer.
//
// Unlike the previous std::stoi-based version, the whole value must parse:
// trailing garbage such as "8080abc" is rejected rather than silently
// truncated to 8080, and no exceptions are involved.
static int get_env_int_or(const char *name, int fallback) {
  const char *v = std::getenv(name);
  if (!v || !*v) return fallback;
  int value = 0;
  const char *end = v + std::strlen(v);
  auto [ptr, ec] = std::from_chars(v, end, value);
  if (ec != std::errc() || ptr != end) return fallback;  // reject partial/overflowing parses
  return value;
}
// True when `pid` refers to a process we may signal (kill(pid, 0) probe).
// Note: a zombie that has not been reaped yet also counts as alive.
static bool is_alive(pid_t pid) {
  return pid > 0 && kill(pid, 0) == 0;
}
// Gracefully stops a worker process: send SIGTERM, poll with waitpid for up
// to `wait_seconds`, and escalate to SIGKILL if the process refuses to exit.
// The child is always reaped, so no zombie is left behind.
//
// Fix: if waitpid reports -1 (ECHILD — the pid was already reaped or is not
// our child) we return immediately instead of spinning for the full wait
// window and then SIGKILLing a pid that may have been recycled.
static void shutdown_worker(pid_t pid, int wait_seconds = 15) {
  if (pid <= 0) return;
  kill(pid, SIGTERM);
  const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(wait_seconds);
  while (std::chrono::steady_clock::now() < deadline) {
    int status = 0;
    pid_t r = waitpid(pid, &status, WNOHANG);
    if (r == pid) return;  // child exited and was reaped
    if (r == -1) return;   // nothing to wait for (already reaped / not our child)
    std::this_thread::sleep_for(std::chrono::milliseconds(200));
  }
  // Deadline passed: force-kill and reap synchronously.
  kill(pid, SIGKILL);
  int status = 0;
  waitpid(pid, &status, 0);
}
| class ModelManager { | |
| public: | |
| ModelManager() | |
| : _default_model(get_env_or("DEFAULT_MODEL", "QuantFactory/Qwen2.5-7B-Instruct-GGUF:q4_k_m")), | |
| _llama_server_bin(get_env_or("LLAMA_SERVER_BIN", "/usr/local/bin/llama-server")), | |
| _worker_host(get_env_or("WORKER_HOST", "127.0.0.1")), | |
| _worker_bind_host(get_env_or("WORKER_BIND_HOST", "0.0.0.0")), | |
| _base_port(get_env_int_or("WORKER_BASE_PORT", 8080)), | |
| _switch_timeout_sec(get_env_int_or("SWITCH_TIMEOUT_SEC", 300)), | |
| _n_ctx(get_env_int_or("MODEL_N_CTX", 8192)), | |
| _n_threads(get_env_int_or("MODEL_THREADS", 4)), | |
| _n_gpu_layers(get_env_int_or("MODEL_NGL", 0)), | |
| _n_batch(get_env_int_or("MODEL_BATCH", 128)), | |
| _n_ubatch(get_env_int_or("MODEL_UBATCH", 64)), | |
| _next_port(_base_port) {} | |
| bool initialize_default(std::string &error) { | |
| return switch_model(_default_model, error); | |
| } | |
| bool switch_model(const std::string &model, std::string &error) { | |
| { | |
| std::lock_guard<std::mutex> lock(_mu); | |
| if (_switch_in_progress) { | |
| error = "Switch already in progress"; | |
| return false; | |
| } | |
| if (_active && _active->model == model && is_alive(_active->pid)) { | |
| return true; | |
| } | |
| _switch_in_progress = true; | |
| } | |
| std::optional<WorkerInfo> old_worker; | |
| { | |
| std::lock_guard<std::mutex> lock(_mu); | |
| if (_active) old_worker = _active; | |
| } | |
| int port = allocate_port(); | |
| pid_t pid = spawn_worker(model, port); | |
| if (pid <= 0) { | |
| finish_switch(false); | |
| error = "Failed to start worker process"; | |
| return false; | |
| } | |
| if (!wait_until_ready(pid, port, _switch_timeout_sec)) { | |
| shutdown_worker(pid); | |
| finish_switch(false); | |
| error = "New model did not become ready in time"; | |
| return false; | |
| } | |
| WorkerInfo new_worker{model, port, pid, now_utc_iso()}; | |
| { | |
| std::lock_guard<std::mutex> lock(_mu); | |
| _active = new_worker; | |
| _switch_in_progress = false; | |
| } | |
| if (old_worker && old_worker->pid != pid) { | |
| shutdown_worker(old_worker->pid); | |
| } | |
| return true; | |
| } | |
| std::optional<WorkerInfo> active_worker() { | |
| std::lock_guard<std::mutex> lock(_mu); | |
| if (_active && is_alive(_active->pid)) return _active; | |
| return std::nullopt; | |
| } | |
| json models_view() { | |
| std::lock_guard<std::mutex> lock(_mu); | |
| json out; | |
| out["status"] = (_active && is_alive(_active->pid)) ? "ready" : "no_active_model"; | |
| out["switch_in_progress"] = _switch_in_progress; | |
| if (_active && is_alive(_active->pid)) { | |
| out["current_model"] = _active->model; | |
| out["last_loaded"] = _active->last_loaded; | |
| out["active_pid"] = _active->pid; | |
| out["active_port"] = _active->port; | |
| } else { | |
| out["current_model"] = nullptr; | |
| out["last_loaded"] = nullptr; | |
| out["active_pid"] = nullptr; | |
| out["active_port"] = nullptr; | |
| } | |
| return out; | |
| } | |
| private: | |
| std::mutex _mu; | |
| std::optional<WorkerInfo> _active; | |
| bool _switch_in_progress = false; | |
| std::string _default_model; | |
| std::string _llama_server_bin; | |
| std::string _worker_host; | |
| std::string _worker_bind_host; | |
| int _base_port; | |
| int _switch_timeout_sec; | |
| int _n_ctx; | |
| int _n_threads; | |
| int _n_gpu_layers; | |
| int _n_batch; | |
| int _n_ubatch; | |
| int _next_port; | |
| int allocate_port() { | |
| std::lock_guard<std::mutex> lock(_mu); | |
| return _next_port++; | |
| } | |
| void finish_switch(bool ok) { | |
| std::lock_guard<std::mutex> lock(_mu); | |
| if (!ok) _switch_in_progress = false; | |
| } | |
| pid_t spawn_worker(const std::string &model, int port) { | |
| pid_t pid = fork(); | |
| if (pid < 0) return -1; | |
| if (pid == 0) { | |
| setsid(); | |
| std::string port_s = std::to_string(port); | |
| std::string n_ctx_s = std::to_string(_n_ctx); | |
| std::string threads_s = std::to_string(_n_threads); | |
| std::string ngl_s = std::to_string(_n_gpu_layers); | |
| std::string batch_s = std::to_string(_n_batch); | |
| std::string ubatch_s = std::to_string(_n_ubatch); | |
| std::vector<std::string> args = { | |
| _llama_server_bin, | |
| "-hf", model, | |
| "--host", _worker_bind_host, | |
| "--port", port_s, | |
| "-c", n_ctx_s, | |
| "-t", threads_s, | |
| "-ngl", ngl_s, | |
| "--cont-batching", | |
| "-b", batch_s, | |
| "--ubatch-size", ubatch_s | |
| }; | |
| std::vector<char *> argv; | |
| argv.reserve(args.size() + 1); | |
| for (auto &s : args) argv.push_back(const_cast<char *>(s.c_str())); | |
| argv.push_back(nullptr); | |
| execvp(argv[0], argv.data()); | |
| _exit(127); | |
| } | |
| return pid; | |
| } | |
| bool wait_until_ready(pid_t pid, int port, int timeout_sec) { | |
| const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(timeout_sec); | |
| while (std::chrono::steady_clock::now() < deadline) { | |
| if (!is_alive(pid)) return false; | |
| try { | |
| auto [status, _] = http_get(port, "/"); | |
| if (status == 200) return true; | |
| } catch (...) { | |
| } | |
| std::this_thread::sleep_for(std::chrono::milliseconds(800)); | |
| } | |
| return false; | |
| } | |
| std::pair<int, std::string> http_get(int port, const std::string &target) { | |
| asio::io_context ioc; | |
| asio::ip::tcp::resolver resolver(ioc); | |
| beast::tcp_stream stream(ioc); | |
| auto const results = resolver.resolve(_worker_host, std::to_string(port)); | |
| stream.connect(results); | |
| http::request<http::string_body> req{http::verb::get, target, 11}; | |
| req.set(http::field::host, _worker_host); | |
| req.set(http::field::user_agent, "llm-manager"); | |
| http::write(stream, req); | |
| beast::flat_buffer buffer; | |
| http::response<http::string_body> res; | |
| http::read(stream, buffer, res); | |
| beast::error_code ec; | |
| stream.socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec); | |
| return {res.result_int(), res.body()}; | |
| } | |
| }; | |
// Monotonically increasing id assigned to every incoming HTTP request,
// used to correlate the log lines produced while handling it.
static std::atomic<uint64_t> g_req_id{1};
| static void log_line(const std::string &line) { | |
| std::cout << "[" << now_utc_iso() << "] " << line << std::endl; | |
| } | |
// Caps `body` at `max_len` bytes for logging, appending a marker when
// anything was cut off.
static std::string truncate_body(const std::string &body, size_t max_len = 2000) {
  if (body.size() > max_len) {
    return body.substr(0, max_len) + "...[truncated]";
  }
  return body;
}
| static std::pair<int, std::string> forward_chat(const WorkerInfo &worker, const std::string &body) { | |
| asio::io_context ioc; | |
| asio::ip::tcp::resolver resolver(ioc); | |
| beast::tcp_stream stream(ioc); | |
| auto const results = resolver.resolve("127.0.0.1", std::to_string(worker.port)); | |
| stream.connect(results); | |
| http::request<http::string_body> req{http::verb::post, "/v1/chat/completions", 11}; | |
| req.set(http::field::host, "127.0.0.1"); | |
| req.set(http::field::content_type, "application/json"); | |
| req.set(http::field::user_agent, "llm-manager"); | |
| req.body() = body; | |
| req.prepare_payload(); | |
| http::write(stream, req); | |
| beast::flat_buffer buffer; | |
| http::response<http::string_body> res; | |
| http::read(stream, buffer, res); | |
| beast::error_code ec; | |
| stream.socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec); | |
| return {res.result_int(), res.body()}; | |
| } | |
// Result of proxying a GET to the worker: status, body, and just enough
// response metadata to relay the answer back to the original client.
struct ProxiedGetResult {
  int status = 500;  // upstream HTTP status; 500 if never filled in
  std::string body;  // upstream response body, relayed verbatim
  std::string content_type = "text/plain; charset=utf-8";  // default when upstream omits Content-Type
  std::string content_encoding;  // set only when upstream sent Content-Encoding
};
| static ProxiedGetResult forward_get_to_worker(const WorkerInfo &worker, | |
| const std::string &target) { | |
| asio::io_context ioc; | |
| asio::ip::tcp::resolver resolver(ioc); | |
| beast::tcp_stream stream(ioc); | |
| auto const results = resolver.resolve("127.0.0.1", std::to_string(worker.port)); | |
| stream.connect(results); | |
| http::request<http::string_body> req{http::verb::get, target, 11}; | |
| req.set(http::field::host, "127.0.0.1"); | |
| req.set(http::field::user_agent, "llm-manager"); | |
| req.set(http::field::accept_encoding, "gzip, identity"); | |
| http::write(stream, req); | |
| beast::flat_buffer buffer; | |
| http::response<http::string_body> res; | |
| http::read(stream, buffer, res); | |
| beast::error_code ec; | |
| stream.socket().shutdown(asio::ip::tcp::socket::shutdown_both, ec); | |
| ProxiedGetResult out; | |
| out.status = res.result_int(); | |
| out.body = res.body(); | |
| if (res.base().find(http::field::content_type) != res.base().end()) { | |
| out.content_type = res.base()[http::field::content_type].to_string(); | |
| } | |
| if (res.base().find(http::field::content_encoding) != res.base().end()) { | |
| out.content_encoding = res.base()[http::field::content_encoding].to_string(); | |
| } | |
| return out; | |
| } | |
| template <typename Body, typename Allocator> | |
| http::response<http::string_body> handle_request( | |
| ModelManager &manager, | |
| http::request<Body, http::basic_fields<Allocator>> &&req) { | |
| const auto start = std::chrono::steady_clock::now(); | |
| const auto req_id = g_req_id.fetch_add(1); | |
| const std::string target = req.target().to_string(); | |
| const std::string method = req.method_string().to_string(); | |
| const std::string path = target.substr(0, target.find('?')); | |
| log_line("request_id=" + std::to_string(req_id) + " method=" + method + " path=" + target); | |
| if constexpr (std::is_same_v<Body, http::string_body>) { | |
| if (!req.body().empty()) { | |
| log_line("request_id=" + std::to_string(req_id) + " body=" + truncate_body(req.body())); | |
| } | |
| } | |
| auto json_response = [&](http::status status, const json &obj) { | |
| http::response<http::string_body> res{status, req.version()}; | |
| res.set(http::field::content_type, "application/json"); | |
| res.set(http::field::server, "llm-manager"); | |
| res.keep_alive(req.keep_alive()); | |
| res.body() = obj.dump(); | |
| res.prepare_payload(); | |
| auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>( | |
| std::chrono::steady_clock::now() - start).count(); | |
| log_line("request_id=" + std::to_string(req_id) + " status=" + std::to_string(res.result_int()) + | |
| " elapsed_ms=" + std::to_string(elapsed_ms)); | |
| return res; | |
| }; | |
| try { | |
| if (path == "/health" && req.method() == http::verb::get) { | |
| return json_response(http::status::ok, manager.models_view()); | |
| } | |
| if (path == "/models" && req.method() == http::verb::get) { | |
| return json_response(http::status::ok, manager.models_view()); | |
| } | |
| if (path == "/switch-model" && req.method() == http::verb::post) { | |
| std::string body(req.body().data(), req.body().size()); | |
| json j = json::parse(body, nullptr, false); | |
| if (j.is_discarded()) { | |
| return json_response(http::status::bad_request, {{"error", "Invalid JSON"}}); | |
| } | |
| std::string model; | |
| if (j.contains("model_name")) model = j["model_name"].get<std::string>(); | |
| if (j.contains("model")) model = j["model"].get<std::string>(); | |
| if (model.empty()) { | |
| return json_response(http::status::bad_request, {{"error", "Expected 'model' or 'model_name'"}}); | |
| } | |
| std::string err; | |
| bool ok = manager.switch_model(model, err); | |
| if (!ok) { | |
| auto status = (err == "Switch already in progress") ? http::status::conflict : http::status::internal_server_error; | |
| return json_response(status, {{"status", "error"}, {"error", err}}); | |
| } | |
| auto state = manager.models_view(); | |
| state["message"] = "Switched model successfully"; | |
| return json_response(http::status::ok, state); | |
| } | |
| if (path == "/v1/chat/completions" && req.method() == http::verb::post) { | |
| auto worker = manager.active_worker(); | |
| if (!worker) { | |
| return json_response(http::status::service_unavailable, {{"error", "No active model"}}); | |
| } | |
| auto [upstream_status, upstream_body] = forward_chat(*worker, req.body()); | |
| http::response<http::string_body> res{static_cast<http::status>(upstream_status), req.version()}; | |
| res.set(http::field::content_type, "application/json"); | |
| res.set(http::field::server, "llm-manager"); | |
| res.keep_alive(req.keep_alive()); | |
| res.body() = upstream_body; | |
| res.prepare_payload(); | |
| auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>( | |
| std::chrono::steady_clock::now() - start).count(); | |
| log_line("request_id=" + std::to_string(req_id) + " model=" + worker->model + | |
| " active_pid=" + std::to_string(worker->pid) + | |
| " active_port=" + std::to_string(worker->port) + | |
| " upstream_status=" + std::to_string(upstream_status) + | |
| " elapsed_ms=" + std::to_string(elapsed_ms)); | |
| return res; | |
| } | |
| // Proxy GET requests not handled by manager endpoints to active llama-server. | |
| // This enables llama.cpp UI/static routes (including "/"). | |
| if (req.method() == http::verb::get) { | |
| auto worker = manager.active_worker(); | |
| if (!worker) { | |
| return json_response(http::status::service_unavailable, {{"error", "No active model"}}); | |
| } | |
| auto upstream = forward_get_to_worker(*worker, target); | |
| http::response<http::string_body> res{ | |
| static_cast<http::status>(upstream.status), req.version()}; | |
| res.set(http::field::content_type, upstream.content_type); | |
| if (!upstream.content_encoding.empty()) { | |
| res.set(http::field::content_encoding, upstream.content_encoding); | |
| } | |
| res.set(http::field::server, "llm-manager"); | |
| res.keep_alive(req.keep_alive()); | |
| res.body() = upstream.body; | |
| res.prepare_payload(); | |
| auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>( | |
| std::chrono::steady_clock::now() - start) | |
| .count(); | |
| log_line("request_id=" + std::to_string(req_id) + | |
| " proxied_get model=" + worker->model + | |
| " upstream_status=" + std::to_string(upstream.status) + | |
| " elapsed_ms=" + std::to_string(elapsed_ms)); | |
| return res; | |
| } | |
| return json_response(http::status::not_found, {{"error", "Not found"}}); | |
| } catch (const std::exception &e) { | |
| return json_response(http::status::internal_server_error, {{"error", e.what()}}); | |
| } | |
| } | |
| void do_session(asio::ip::tcp::socket socket, ModelManager &manager) { | |
| try { | |
| beast::flat_buffer buffer; | |
| http::request<http::string_body> req; | |
| http::read(socket, buffer, req); | |
| auto res = handle_request(manager, std::move(req)); | |
| http::write(socket, res); | |
| beast::error_code ec; | |
| socket.shutdown(asio::ip::tcp::socket::shutdown_send, ec); | |
| } catch (...) { | |
| } | |
| } | |
| int main() { | |
| const auto bind_host = get_env_or("MANAGER_HOST", "0.0.0.0"); | |
| const int bind_port = get_env_int_or("MANAGER_PORT", 7860); | |
| ModelManager manager; | |
| std::string init_error; | |
| log_line("startup: loading default model"); | |
| if (!manager.initialize_default(init_error)) { | |
| log_line("startup: default model failed: " + init_error); | |
| } else { | |
| log_line("startup: default model ready"); | |
| } | |
| asio::io_context ioc{1}; | |
| asio::ip::tcp::acceptor acceptor{ioc, {asio::ip::make_address(bind_host), static_cast<unsigned short>(bind_port)}}; | |
| log_line("manager listening on " + bind_host + ":" + std::to_string(bind_port)); | |
| for (;;) { | |
| asio::ip::tcp::socket socket{ioc}; | |
| acceptor.accept(socket); | |
| std::thread(&do_session, std::move(socket), std::ref(manager)).detach(); | |
| } | |
| } | |