Spaces:
Runtime error
Runtime error
package proxy | |
import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"io"
	"log"
	"mime"
	"net/http"

	"github.com/arpinfidel/p2p-llm/config"
	"github.com/arpinfidel/p2p-llm/db"
)
type ProxyHandler struct { | |
cfg *config.Config | |
peerRepo db.PeerRepository | |
queue chan Request | |
maxParallelRequests int | |
maxParallelPeerRequests int | |
} | |
type Request struct { | |
W http.ResponseWriter | |
R *http.Request | |
} | |
func NewProxyHandler(cfg *config.Config, peerRepo db.PeerRepository) *ProxyHandler { | |
return &ProxyHandler{ | |
cfg: cfg, | |
peerRepo: peerRepo, | |
queue: make(chan Request, 100), // Hardcoded queue size | |
maxParallelRequests: cfg.MaxParallelRequests, | |
maxParallelPeerRequests: 5, // Hardcoded peer requests limit | |
} | |
} | |
func (h *ProxyHandler) Handle(w http.ResponseWriter, r *http.Request) { | |
req := Request{W: w, R: r} | |
h.queue <- req | |
} | |
func (h *ProxyHandler) Run() { | |
// Get peers from config | |
peers := h.cfg.TrustedPeers | |
// Get peers from database | |
dbPeers, err := h.peerRepo.ListTrustedPeers() | |
if err != nil { | |
log.Printf("Error getting peers from database: %v", err) | |
} else { | |
for _, p := range dbPeers { | |
block, _ := pem.Decode([]byte(p.PublicKey)) | |
if block == nil { | |
continue | |
} | |
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes) | |
if err != nil { | |
continue | |
} | |
peers = append(peers, config.Peer{ | |
URL: p.URL, | |
PublicKey: pubKey.(*rsa.PublicKey), | |
}) | |
} | |
} | |
// Start workers for target URL | |
for range h.maxParallelRequests { | |
go func() { | |
for req := range h.queue { | |
h.Forward(req, h.cfg.TargetURL) | |
} | |
}() | |
} | |
// Start workers for each peer | |
for _, peer := range peers { | |
go func(p config.Peer) { | |
for req := range h.queue { | |
h.Forward(req, p.URL) | |
} | |
}(peer) | |
} | |
} | |
// Shared state for Forward.
var (
	// upstreamClient is reused for every forwarded request. It deliberately has
	// no Timeout: event streams may stay open for a long time, and cancellation
	// comes from the incoming request's context instead.
	upstreamClient = &http.Client{}

	// hopByHopHeaders are connection-scoped headers that a proxy must not relay
	// (RFC 9110 §7.6.1); the list mirrors net/http/httputil.ReverseProxy.
	hopByHopHeaders = map[string]struct{}{
		"Connection":          {},
		"Proxy-Connection":    {},
		"Keep-Alive":          {},
		"Proxy-Authenticate":  {},
		"Proxy-Authorization": {},
		"Te":                  {},
		"Trailer":             {},
		"Transfer-Encoding":   {},
		"Upgrade":             {},
	}
)

// Forward relays req to the upstream server at url and writes the upstream
// response back to req.W.
//
// url is prepended verbatim to the incoming request URI (path and query), so
// it should be an origin such as "http://host:port" without a trailing slash.
// The method, end-to-end headers and body are passed through. The upstream
// request is bound to req.R's context, so it is cancelled if the client goes
// away. Responses whose media type is text/event-stream are flushed to the
// client after every read; other bodies are copied with io.Copy.
//
// Failures before the response starts produce a 500 (request construction,
// flushing unsupported) or a 502 (upstream unreachable). Failures in the
// middle of a body can only be logged.
func (h *ProxyHandler) Forward(req Request, url string) {
	targetReq, err := buildUpstreamRequest(req.R, url)
	if err != nil {
		log.Printf("Error creating request: %v", err)
		http.Error(req.W, "Error creating request", http.StatusInternalServerError)
		return
	}

	resp, err := upstreamClient.Do(targetReq)
	if err != nil {
		log.Printf("Error forwarding request: %v", err)
		http.Error(req.W, "Error forwarding request", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	if isEventStream(resp.Header) {
		relayEventStream(req.W, resp)
		return
	}

	// Headers must be in place before WriteHeader; later changes are ignored.
	copyEndToEndHeaders(req.W.Header(), resp.Header)
	req.W.WriteHeader(resp.StatusCode)
	if _, err := io.Copy(req.W, resp.Body); err != nil {
		log.Printf("Error copying response body: %v", err)
	}
}

// buildUpstreamRequest clones in's method, request URI, body and end-to-end
// headers into a new request aimed at baseURL, bound to in's context.
func buildUpstreamRequest(in *http.Request, baseURL string) (*http.Request, error) {
	// RequestURI keeps the query string and the original path escaping.
	out, err := http.NewRequestWithContext(in.Context(), in.Method, baseURL+in.URL.RequestURI(), in.Body)
	if err != nil {
		return nil, err
	}
	// NewRequest cannot infer the length of an arbitrary io.ReadCloser and would
	// otherwise fall back to chunked encoding; -1 (unknown) passes through as-is.
	out.ContentLength = in.ContentLength
	copyEndToEndHeaders(out.Header, in.Header)
	return out, nil
}

// copyEndToEndHeaders adds every header in src to dst, except hop-by-hop ones.
func copyEndToEndHeaders(dst, src http.Header) {
	for name, values := range src {
		if _, hop := hopByHopHeaders[http.CanonicalHeaderKey(name)]; hop {
			continue
		}
		for _, v := range values {
			dst.Add(name, v)
		}
	}
}

// isEventStream reports whether the Content-Type in hdr is text/event-stream.
// Parameters such as "; charset=utf-8" are ignored.
func isEventStream(hdr http.Header) bool {
	mediaType, _, err := mime.ParseMediaType(hdr.Get("Content-Type"))
	return err == nil && mediaType == "text/event-stream"
}

// relayEventStream writes resp's status and headers to w, then copies the body,
// flushing after every chunk so events reach the client without buffering.
// If w cannot flush, it responds 500 instead, before anything else is written.
func relayEventStream(w http.ResponseWriter, resp *http.Response) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		log.Printf("ResponseWriter does not support flushing")
		http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
		return
	}

	copyEndToEndHeaders(w.Header(), resp.Header)
	// Discourage caching of the stream. Connection management and chunked
	// transfer encoding are handled by net/http itself and must not be set here.
	w.Header().Set("Cache-Control", "no-cache")
	w.WriteHeader(resp.StatusCode)

	buf := make([]byte, 32*1024)
	for {
		n, err := resp.Body.Read(buf)
		if n > 0 {
			if _, writeErr := w.Write(buf[:n]); writeErr != nil {
				log.Printf("Error writing to client: %v", writeErr)
				return
			}
			flusher.Flush()
		}
		if err != nil {
			if err != io.EOF {
				log.Printf("Error reading from response body: %v", err)
			}
			return
		}
	}
}