package proxy

import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"log"
	"mime"
	"net/http"

	"github.com/arpinfidel/p2p-llm/config"
	"github.com/arpinfidel/p2p-llm/db"
)

const (
	// defaultQueueSize bounds how many requests may wait for a free worker
	// before Handle blocks on enqueueing.
	defaultQueueSize = 100
	// defaultMaxParallelPeerRequests is the value stored in
	// ProxyHandler.maxParallelPeerRequests.
	defaultMaxParallelPeerRequests = 5
)

// upstreamClient is shared by every worker instead of being allocated per
// request. It deliberately has no overall Timeout: streamed (SSE) responses can
// legitimately run for a long time, so cancellation comes from the incoming
// request's context instead.
var upstreamClient = &http.Client{}

// ProxyHandler queues incoming HTTP requests and forwards each one to either
// the configured target URL or one of the trusted peers, whichever worker
// picks it up first. Run must be called before Handle can make progress.
type ProxyHandler struct {
	cfg      *config.Config
	peerRepo db.PeerRepository
	// queue hands requests from Handle to the workers started by Run.
	queue chan queuedRequest
	// maxParallelRequests is the number of workers forwarding to cfg.TargetURL.
	maxParallelRequests int
	// NOTE(review): maxParallelPeerRequests is stored but never consulted; Run
	// starts exactly one worker per peer. Confirm the intended semantics.
	maxParallelPeerRequests int
}

// Request is an incoming request together with the writer for its response.
type Request struct {
	W http.ResponseWriter
	R *http.Request
}

// queuedRequest is a Request waiting for a worker, plus a channel the worker
// closes once the response has been completely written.
type queuedRequest struct {
	req  Request
	done chan struct{}
}

// NewProxyHandler builds a ProxyHandler that forwards to cfg.TargetURL with
// cfg.MaxParallelRequests workers, plus the trusted peers from cfg and
// peerRepo once Run is called.
func NewProxyHandler(cfg *config.Config, peerRepo db.PeerRepository) *ProxyHandler {
	return &ProxyHandler{
		cfg:                     cfg,
		peerRepo:                peerRepo,
		queue:                   make(chan queuedRequest, defaultQueueSize),
		maxParallelRequests:     cfg.MaxParallelRequests,
		maxParallelPeerRequests: defaultMaxParallelPeerRequests,
	}
}

// Handle is an http.HandlerFunc that enqueues the request for a worker and
// blocks until that worker has written the full response.
//
// Blocking is required: a ResponseWriter must not be used after the handler
// returns, so returning early would let the worker write to a finished
// response. If the client disconnects while the queue is full, the request is
// dropped without being enqueued.
func (h *ProxyHandler) Handle(w http.ResponseWriter, r *http.Request) {
	job := queuedRequest{
		req:  Request{W: w, R: r},
		done: make(chan struct{}),
	}

	select {
	case h.queue <- job:
	case <-r.Context().Done():
		// Nothing was enqueued, so no worker will touch w; safe to return.
		return
	}

	// Once enqueued we must wait even if the client goes away: Forward uses
	// r.Context(), so a cancelled request finishes quickly anyway.
	<-job.done
}

// Run loads trusted peers (from config and the database) and starts the
// workers that drain the queue: maxParallelRequests workers for
// cfg.TargetURL and one worker per peer. The workers run for as long as the
// queue stays open; Run itself returns immediately.
func (h *ProxyHandler) Run() {
	dbPeers := h.loadDBPeers()

	// Build a fresh slice: appending directly to cfg.TrustedPeers could write
	// into spare capacity of the config's backing array.
	peers := make([]config.Peer, 0, len(h.cfg.TrustedPeers)+len(dbPeers))
	peers = append(peers, h.cfg.TrustedPeers...)
	peers = append(peers, dbPeers...)

	for range h.maxParallelRequests {
		go h.work(h.cfg.TargetURL)
	}

	for _, p := range peers {
		go h.work(p.URL)
	}
}

// work forwards queued requests to targetURL, closing each request's done
// channel after Forward returns so the waiting Handle call can finish.
func (h *ProxyHandler) work(targetURL string) {
	for job := range h.queue {
		h.Forward(job.req, targetURL)
		close(job.done)
	}
}

// loadDBPeers returns the trusted peers stored in the database. Database
// errors are logged and yield no peers; individual records whose public key
// cannot be decoded as a PEM-encoded PKIX RSA key are logged and skipped.
func (h *ProxyHandler) loadDBPeers() []config.Peer {
	records, err := h.peerRepo.ListTrustedPeers()
	if err != nil {
		log.Printf("Error getting peers from database: %v", err)
		return nil
	}

	peers := make([]config.Peer, 0, len(records))
	for _, rec := range records {
		pubKey, err := parseRSAPublicKeyPEM([]byte(rec.PublicKey))
		if err != nil {
			log.Printf("Skipping trusted peer %q: %v", rec.URL, err)
			continue
		}
		peers = append(peers, config.Peer{
			URL:       rec.URL,
			PublicKey: pubKey,
		})
	}
	return peers
}

// parseRSAPublicKeyPEM decodes the first PEM block in data as a PKIX public
// key and returns it if it is an RSA key. Other key types are rejected with
// an error instead of panicking on the type assertion.
func parseRSAPublicKeyPEM(data []byte) (*rsa.PublicKey, error) {
	block, _ := pem.Decode(data)
	if block == nil {
		return nil, errors.New("no PEM block found in public key")
	}

	key, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("parsing PKIX public key: %w", err)
	}

	rsaKey, ok := key.(*rsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("public key is %T, want *rsa.PublicKey", key)
	}
	return rsaKey, nil
}

// Forward sends req to url followed by the request's path and query, and
// writes the upstream response back to req.W.
//
// Event-stream (SSE) responses are flushed to the client chunk by chunk; all
// other responses are copied straight through. The upstream call is bound to
// req.R's context, so it is abandoned when the client disconnects. Failures
// before anything is written are reported as 500 (request construction) or
// 502 (upstream call).
func (h *ProxyHandler) Forward(req Request, url string) {
	// EscapedPath keeps percent-encoding intact (so encoded '?', '#', '/'
	// survive), and the query string is carried over explicitly.
	target := url + req.R.URL.EscapedPath()
	if req.R.URL.RawQuery != "" {
		target += "?" + req.R.URL.RawQuery
	}

	targetReq, err := http.NewRequestWithContext(req.R.Context(), req.R.Method, target, req.R.Body)
	if err != nil {
		log.Printf("Error creating request: %v", err)
		http.Error(req.W, "Error creating request", http.StatusInternalServerError)
		return
	}
	// NewRequest cannot infer the length of an arbitrary body reader and the
	// transport would otherwise send it chunked; carry the client's length.
	targetReq.ContentLength = req.R.ContentLength
	copyHeaders(targetReq.Header, req.R.Header)

	resp, err := upstreamClient.Do(targetReq)
	if err != nil {
		log.Printf("Error forwarding request: %v", err)
		http.Error(req.W, "Error forwarding request", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	if isEventStream(resp.Header.Get("Content-Type")) {
		streamEvents(req.W, resp)
		return
	}

	copyHeaders(req.W.Header(), resp.Header)
	req.W.WriteHeader(resp.StatusCode)
	if _, err := io.Copy(req.W, resp.Body); err != nil {
		log.Printf("Error copying response body: %v", err)
	}
}

// streamEvents relays an SSE response to w, flushing after every chunk read
// from upstream so events reach the client without buffering delay. If w
// cannot flush, a 500 is returned before any upstream header is written.
func streamEvents(w http.ResponseWriter, resp *http.Response) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		log.Printf("ResponseWriter does not support flushing")
		http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
		return
	}

	// Header changes only take effect before WriteHeader. Connection and
	// Transfer-Encoding are not set here: net/http manages message framing
	// itself, and those connection-specific fields are invalid in HTTP/2.
	copyHeaders(w.Header(), resp.Header)
	w.Header().Set("Cache-Control", "no-cache")
	w.WriteHeader(resp.StatusCode)
	flusher.Flush() // let the client see the status line before the first event

	buf := make([]byte, 1024)
	for {
		n, readErr := resp.Body.Read(buf)
		if n > 0 {
			if _, writeErr := w.Write(buf[:n]); writeErr != nil {
				log.Printf("Error writing to client: %v", writeErr)
				return
			}
			flusher.Flush()
		}
		if readErr != nil {
			if readErr != io.EOF {
				log.Printf("Error reading from response body: %v", readErr)
			}
			return
		}
	}
}

// isEventStream reports whether contentType names the text/event-stream media
// type, ignoring parameters such as charset and letter case.
func isEventStream(contentType string) bool {
	// ParseMediaType's error is ignored on purpose: on a malformed parameter
	// it still returns the media type, and on other errors it returns "".
	mediaType, _, _ := mime.ParseMediaType(contentType)
	return mediaType == "text/event-stream"
}

// copyHeaders appends every end-to-end header in src to dst, skipping
// hop-by-hop headers that apply only to a single connection.
func copyHeaders(dst, src http.Header) {
	for name, values := range src {
		if isHopByHopHeader(name) {
			continue
		}
		for _, value := range values {
			dst.Add(name, value)
		}
	}
}

// isHopByHopHeader reports whether name is a connection-specific header
// (RFC 7230 §6.1) that a proxy must not forward.
func isHopByHopHeader(name string) bool {
	switch http.CanonicalHeaderKey(name) {
	case "Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization",
		"Proxy-Connection", "Te", "Trailer", "Transfer-Encoding", "Upgrade":
		return true
	}
	return false
}