Spaces:
Sleeping
Sleeping
| use axum::{ | |
| Router, | |
| routing::{get, post}, | |
| extract::DefaultBodyLimit, | |
| response::{IntoResponse, Response, Json}, | |
| }; | |
| use tracing::{debug, error}; | |
| use tower_http::trace::TraceLayer; | |
| use std::sync::Arc; | |
| use tokio::sync::oneshot; | |
| use crate::proxy::TokenManager; | |
| /// Axum 应用状态 | |
| pub struct AppState { | |
| pub token_manager: Arc<TokenManager>, | |
| pub anthropic_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>, | |
| pub openai_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>, | |
| pub custom_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>, | |
| pub request_timeout: u64, // API 请求超时(秒) | |
| pub thought_signature_map: Arc<tokio::sync::Mutex<std::collections::HashMap<String, String>>>, // 思维链签名映射 (ID -> Signature) | |
| pub upstream_proxy: Arc<tokio::sync::RwLock<crate::proxy::config::UpstreamProxyConfig>>, | |
| pub upstream: Arc<crate::proxy::upstream::client::UpstreamClient>, | |
| } | |
| /// Axum 服务器实例 | |
| pub struct AxumServer { | |
| shutdown_tx: Option<oneshot::Sender<()>>, | |
| anthropic_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>, | |
| openai_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>, | |
| custom_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>, | |
| proxy_state: Arc<tokio::sync::RwLock<crate::proxy::config::UpstreamProxyConfig>>, | |
| } | |
| impl AxumServer { | |
| pub async fn update_mapping(&self, config: &crate::proxy::config::ProxyConfig) { | |
| { | |
| let mut m = self.anthropic_mapping.write().await; | |
| *m = config.anthropic_mapping.clone(); | |
| } | |
| { | |
| let mut m = self.openai_mapping.write().await; | |
| *m = config.openai_mapping.clone(); | |
| } | |
| { | |
| let mut m = self.custom_mapping.write().await; | |
| *m = config.custom_mapping.clone(); | |
| } | |
| tracing::info!("模型映射 (Anthropic/OpenAI/Custom) 已全量热更新"); | |
| } | |
| /// 更新代理配置 | |
| pub async fn update_proxy(&self, new_config: crate::proxy::config::UpstreamProxyConfig) { | |
| let mut proxy = self.proxy_state.write().await; | |
| *proxy = new_config; | |
| tracing::info!("上游代理配置已热更新"); | |
| } | |
| /// 启动 Axum 服务器 | |
| pub async fn start( | |
| host: String, | |
| port: u16, | |
| token_manager: Arc<TokenManager>, | |
| anthropic_mapping: std::collections::HashMap<String, String>, | |
| openai_mapping: std::collections::HashMap<String, String>, | |
| custom_mapping: std::collections::HashMap<String, String>, | |
| _request_timeout: u64, | |
| upstream_proxy: crate::proxy::config::UpstreamProxyConfig, | |
| ) -> Result<(Self, tokio::task::JoinHandle<()>), String> { | |
| let mapping_state = Arc::new(tokio::sync::RwLock::new(anthropic_mapping)); | |
| let openai_mapping_state = Arc::new(tokio::sync::RwLock::new(openai_mapping)); | |
| let custom_mapping_state = Arc::new(tokio::sync::RwLock::new(custom_mapping)); | |
| let proxy_state = Arc::new(tokio::sync::RwLock::new(upstream_proxy.clone())); | |
| let state = AppState { | |
| token_manager: token_manager.clone(), | |
| anthropic_mapping: mapping_state.clone(), | |
| openai_mapping: openai_mapping_state.clone(), | |
| custom_mapping: custom_mapping_state.clone(), | |
| request_timeout: 300, // 5分钟超时 | |
| thought_signature_map: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())), | |
| upstream_proxy: proxy_state.clone(), | |
| upstream: Arc::new(crate::proxy::upstream::client::UpstreamClient::new(Some(upstream_proxy.clone()))), | |
| }; | |
| // 构建路由 - 使用新架构的 handlers! | |
| use crate::proxy::handlers; | |
| // 构建路由 | |
| let app = Router::new() | |
| // OpenAI Protocol | |
| .route("/v1/models", get(handlers::openai::handle_list_models)) | |
| .route("/v1/chat/completions", post(handlers::openai::handle_chat_completions)) | |
| .route("/v1/completions", post(handlers::openai::handle_completions)) | |
| .route("/v1/responses", post(handlers::openai::handle_completions)) // 兼容 Codex CLI | |
| // Claude Protocol | |
| .route("/v1/messages", post(handlers::claude::handle_messages)) | |
| .route("/v1/messages/count_tokens", post(handlers::claude::handle_count_tokens)) | |
| .route("/v1/models/claude", get(handlers::claude::handle_list_models)) | |
| // Gemini Protocol (Native) | |
| .route("/v1beta/models", get(handlers::gemini::handle_list_models)) | |
| // Handle both GET (get info) and POST (generateContent with colon) at the same route | |
| .route("/v1beta/models/:model", get(handlers::gemini::handle_get_model).post(handlers::gemini::handle_generate)) | |
| .route("/v1beta/models/:model/countTokens", post(handlers::gemini::handle_count_tokens)) // Specific route priority | |
| .route("/healthz", get(health_check_handler)) | |
| .layer(DefaultBodyLimit::max(100 * 1024 * 1024)) | |
| .layer(TraceLayer::new_for_http()) | |
| .layer(axum::middleware::from_fn(crate::proxy::middleware::auth_middleware)) | |
| .layer(crate::proxy::middleware::cors_layer()) | |
| .with_state(state); | |
| // 绑定地址 | |
| let addr = format!("{}:{}", host, port); | |
| let listener = tokio::net::TcpListener::bind(&addr) | |
| .await | |
| .map_err(|e| format!("地址 {} 绑定失败: {}", addr, e))?; | |
| tracing::info!("反代服务器启动在 http://{}", addr); | |
| // 创建关闭通道 | |
| let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>(); | |
| let server_instance = Self { | |
| shutdown_tx: Some(shutdown_tx), | |
| anthropic_mapping: mapping_state.clone(), | |
| openai_mapping: openai_mapping_state.clone(), | |
| custom_mapping: custom_mapping_state.clone(), | |
| proxy_state, | |
| }; | |
| // 在新任务中启动服务器 | |
| let handle = tokio::spawn(async move { | |
| use hyper_util::rt::TokioIo; | |
| use hyper::server::conn::http1; | |
| use hyper_util::service::TowerToHyperService; | |
| loop { | |
| tokio::select! { | |
| res = listener.accept() => { | |
| match res { | |
| Ok((stream, _)) => { | |
| let io = TokioIo::new(stream); | |
| let service = TowerToHyperService::new(app.clone()); | |
| tokio::task::spawn(async move { | |
| if let Err(err) = http1::Builder::new() | |
| .serve_connection(io, service) | |
| .with_upgrades() // 支持 WebSocket (如果以后需要) | |
| .await | |
| { | |
| debug!("连接处理结束或出错: {:?}", err); | |
| } | |
| }); | |
| } | |
| Err(e) => { | |
| error!("接收连接失败: {:?}", e); | |
| } | |
| } | |
| } | |
| _ = &mut shutdown_rx => { | |
| tracing::info!("反代服务器停止监听"); | |
| break; | |
| } | |
| } | |
| } | |
| }); | |
| Ok(( | |
| server_instance, | |
| handle, | |
| )) | |
| } | |
| /// 停止服务器 | |
| pub fn stop(mut self) { | |
| if let Some(tx) = self.shutdown_tx.take() { | |
| let _ = tx.send(()); | |
| } | |
| } | |
| } | |
| // ===== API 处理器 (旧代码已移除,由 src/proxy/handlers/* 接管) ===== | |
| /// 健康检查处理器 | |
| async fn health_check_handler() -> Response { | |
| Json(serde_json::json!({ | |
| "status": "ok" | |
| })).into_response() | |
| } | |