Spaces:
Runtime error
Runtime error
| package proxy | |
| import ( | |
| "crypto/rsa" | |
| "crypto/x509" | |
| "encoding/pem" | |
| "io" | |
| "log" | |
| "net/http" | |
| "github.com/arpinfidel/p2p-llm/config" | |
| "github.com/arpinfidel/p2p-llm/db" | |
| ) | |
| type ProxyHandler struct { | |
| cfg *config.Config | |
| peerRepo db.PeerRepository | |
| queue chan Request | |
| maxParallelRequests int | |
| maxParallelPeerRequests int | |
| } | |
| type Request struct { | |
| W http.ResponseWriter | |
| R *http.Request | |
| } | |
| func NewProxyHandler(cfg *config.Config, peerRepo db.PeerRepository) *ProxyHandler { | |
| return &ProxyHandler{ | |
| cfg: cfg, | |
| peerRepo: peerRepo, | |
| queue: make(chan Request, 100), // Hardcoded queue size | |
| maxParallelRequests: cfg.MaxParallelRequests, | |
| maxParallelPeerRequests: 5, // Hardcoded peer requests limit | |
| } | |
| } | |
| func (h *ProxyHandler) Handle(w http.ResponseWriter, r *http.Request) { | |
| req := Request{W: w, R: r} | |
| h.queue <- req | |
| } | |
| func (h *ProxyHandler) Run() { | |
| // Get peers from config | |
| peers := h.cfg.TrustedPeers | |
| // Get peers from database | |
| dbPeers, err := h.peerRepo.ListTrustedPeers() | |
| if err != nil { | |
| log.Printf("Error getting peers from database: %v", err) | |
| } else { | |
| for _, p := range dbPeers { | |
| block, _ := pem.Decode([]byte(p.PublicKey)) | |
| if block == nil { | |
| continue | |
| } | |
| pubKey, err := x509.ParsePKIXPublicKey(block.Bytes) | |
| if err != nil { | |
| continue | |
| } | |
| peers = append(peers, config.Peer{ | |
| URL: p.URL, | |
| PublicKey: pubKey.(*rsa.PublicKey), | |
| }) | |
| } | |
| } | |
| // Start workers for target URL | |
| for range h.maxParallelRequests { | |
| go func() { | |
| for req := range h.queue { | |
| h.Forward(req, h.cfg.TargetURL) | |
| } | |
| }() | |
| } | |
| // Start workers for each peer | |
| for _, peer := range peers { | |
| go func(p config.Peer) { | |
| for req := range h.queue { | |
| h.Forward(req, p.URL) | |
| } | |
| }(peer) | |
| } | |
| } | |
| func (h *ProxyHandler) Forward(req Request, url string) { | |
| // Create a new request to the target server | |
| targetReq, err := http.NewRequest(req.R.Method, url+req.R.URL.Path, req.R.Body) | |
| if err != nil { | |
| log.Printf("Error creating request: %v", err) | |
| http.Error(req.W, "Error creating request", http.StatusInternalServerError) | |
| return | |
| } | |
| // Copy headers from original request | |
| for name, values := range req.R.Header { | |
| for _, value := range values { | |
| targetReq.Header.Add(name, value) | |
| } | |
| } | |
| // Create HTTP client | |
| client := &http.Client{} | |
| // Send the request to the target server | |
| resp, err := client.Do(targetReq) | |
| if err != nil { | |
| log.Printf("Error forwarding request: %v", err) | |
| http.Error(req.W, "Error forwarding request", http.StatusBadGateway) | |
| return | |
| } | |
| defer resp.Body.Close() | |
| // Check if this is an SSE response | |
| isSSE := false | |
| for name, values := range resp.Header { | |
| for _, value := range values { | |
| req.W.Header().Add(name, value) | |
| if name == "Content-Type" && value == "text/event-stream" { | |
| isSSE = true | |
| } | |
| } | |
| } | |
| // Set response status code | |
| req.W.WriteHeader(resp.StatusCode) | |
| // Handle SSE responses differently | |
| if isSSE { | |
| // Set necessary headers for SSE | |
| req.W.Header().Set("Content-Type", "text/event-stream") | |
| req.W.Header().Set("Cache-Control", "no-cache") | |
| req.W.Header().Set("Connection", "keep-alive") | |
| req.W.Header().Set("Transfer-Encoding", "chunked") | |
| // Create a flusher if the ResponseWriter supports it | |
| flusher, ok := req.W.(http.Flusher) | |
| if !ok { | |
| log.Printf("ResponseWriter does not support flushing") | |
| http.Error(req.W, "Streaming unsupported", http.StatusInternalServerError) | |
| return | |
| } | |
| // Buffer for reading from response body | |
| buf := make([]byte, 1024) | |
| for { | |
| n, err := resp.Body.Read(buf) | |
| if n > 0 { | |
| // Write data to client | |
| if _, writeErr := req.W.Write(buf[:n]); writeErr != nil { | |
| log.Printf("Error writing to client: %v", writeErr) | |
| break | |
| } | |
| // Flush data immediately to client | |
| flusher.Flush() | |
| } | |
| if err != nil { | |
| if err != io.EOF { | |
| log.Printf("Error reading from response body: %v", err) | |
| } | |
| break | |
| } | |
| } | |
| } else { | |
| // For non-SSE responses, just copy the body | |
| if _, err := io.Copy(req.W, resp.Body); err != nil { | |
| log.Printf("Error copying response body: %v", err) | |
| } | |
| } | |
| } | |