p2p-llm / proxy /handler.go
arpinfidel's picture
temp
48511d8
package proxy
import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"io"
	"log"
	"mime"
	"net/http"

	"github.com/arpinfidel/p2p-llm/config"
	"github.com/arpinfidel/p2p-llm/db"
)
// ProxyHandler queues incoming HTTP requests and hands them to worker
// goroutines that forward them either to the configured target URL or to a
// trusted peer.
type ProxyHandler struct {
// cfg supplies TargetURL, TrustedPeers and MaxParallelRequests.
cfg *config.Config
// peerRepo is queried by Run (ListTrustedPeers) for additional trusted peers.
peerRepo db.PeerRepository
// queue holds pending requests: Handle sends on it, the workers started by
// Run receive from it.
queue chan Request
// maxParallelRequests is the number of workers Run starts for cfg.TargetURL.
maxParallelRequests int
// maxParallelPeerRequests is set by NewProxyHandler.
// NOTE(review): nothing in the visible code reads this field — confirm
// whether it is meant to bound the per-peer worker count.
maxParallelPeerRequests int
}
// Request pairs an incoming HTTP request with the ResponseWriter its response
// must be written to, so both can travel through the handler's queue to a
// worker goroutine.
type Request struct {
	W http.ResponseWriter
	R *http.Request

	// done is closed by the worker once it has finished with W and R. It is
	// nil for a Request that was not created by Handle.
	done chan struct{}
}

// NewProxyHandler builds a ProxyHandler around cfg and peerRepo. The request
// queue is buffered with room for 100 entries and the number of target workers
// is taken from cfg.MaxParallelRequests. No goroutines are started here; call
// Run for that.
func NewProxyHandler(cfg *config.Config, peerRepo db.PeerRepository) *ProxyHandler {
	return &ProxyHandler{
		cfg:                     cfg,
		peerRepo:                peerRepo,
		queue:                   make(chan Request, 100), // Hardcoded queue size
		maxParallelRequests:     cfg.MaxParallelRequests,
		maxParallelPeerRequests: 5, // Hardcoded peer requests limit
	}
}

// Handle enqueues the request for a worker and blocks until that worker has
// finished writing the response.
//
// Blocking is required: net/http finalizes the response and invalidates both
// the ResponseWriter and the request body as soon as the handler returns, so
// returning right after enqueueing would leave the worker writing to a dead
// ResponseWriter. Run must have been called, otherwise no worker ever picks
// the request up.
func (h *ProxyHandler) Handle(w http.ResponseWriter, r *http.Request) {
	req := Request{W: w, R: r, done: make(chan struct{})}

	select {
	case h.queue <- req:
	case <-r.Context().Done():
		// The request was cancelled while the queue was full; nothing has been
		// handed to a worker, so it is safe to give up here.
		return
	}

	// Once a worker owns the request we must not return before it is done,
	// even if the client disconnects: the worker may still be touching w.
	<-req.done
}

// Run collects the trusted peers (from the config and from the database) and
// starts the worker goroutines. It returns as soon as the workers are started.
//
// All workers receive from the same queue: maxParallelRequests of them forward
// to cfg.TargetURL and one per trusted peer forwards to that peer's URL. The
// workers run until the queue is closed.
func (h *ProxyHandler) Run() {
	// Work on a copy so that appending database peers can never write into the
	// backing array of the config's own slice.
	peers := append([]config.Peer(nil), h.cfg.TrustedPeers...)

	// Add the peers stored in the database. A lookup failure is logged and the
	// handler carries on with the configured peers only.
	dbPeers, err := h.peerRepo.ListTrustedPeers()
	if err != nil {
		log.Printf("Error getting peers from database: %v", err)
	} else {
		for _, p := range dbPeers {
			block, _ := pem.Decode([]byte(p.PublicKey))
			if block == nil {
				log.Printf("Skipping peer %v: public key is not valid PEM", p.URL)
				continue
			}
			parsed, err := x509.ParsePKIXPublicKey(block.Bytes)
			if err != nil {
				log.Printf("Skipping peer %v: cannot parse public key: %v", p.URL, err)
				continue
			}
			// ParsePKIXPublicKey also returns ECDSA, Ed25519, ... keys; an
			// unchecked assertion would panic on those.
			rsaKey, ok := parsed.(*rsa.PublicKey)
			if !ok {
				log.Printf("Skipping peer %v: public key has type %T, want *rsa.PublicKey", p.URL, parsed)
				continue
			}
			peers = append(peers, config.Peer{
				URL:       p.URL,
				PublicKey: rsaKey,
			})
		}
	}

	// serve forwards one queued request and then releases the Handle call that
	// is waiting for it.
	serve := func(req Request, url string) {
		if req.done != nil {
			defer close(req.done)
		}
		h.Forward(req, url)
	}

	// Start workers for the target URL.
	for range h.maxParallelRequests {
		go func() {
			for req := range h.queue {
				serve(req, h.cfg.TargetURL)
			}
		}()
	}

	// Start one worker per peer.
	// NOTE(review): maxParallelPeerRequests is not consulted here — confirm
	// whether each peer is meant to get that many workers.
	for _, peer := range peers {
		go func(p config.Peer) {
			for req := range h.queue {
				serve(req, p.URL)
			}
		}(peer)
	}
}
// Forward sends req to the upstream server rooted at url and relays the
// upstream response to req.W.
//
// The upstream request uses the original method, path, query string, headers
// and body. The upstream status code and headers are copied to the client.
// A text/event-stream response is relayed chunk by chunk with a flush after
// every write so events reach the client immediately; any other response is
// copied with io.Copy.
//
// Failures before the response has started are reported to the client as 500
// (request could not be built, or streaming is unsupported by req.W) or 502
// (upstream unreachable). Failures while relaying the body are only logged,
// because the status line has already been sent by then.
func (h *ProxyHandler) Forward(req Request, url string) {
	// Keep the path exactly as the client encoded it and carry the query
	// string along; url + URL.Path alone would silently drop "?a=b".
	target := url + req.R.URL.EscapedPath()
	if query := req.R.URL.RawQuery; query != "" {
		target += "?" + query
	}

	targetReq, err := http.NewRequest(req.R.Method, target, req.R.Body)
	if err != nil {
		log.Printf("Error creating request: %v", err)
		http.Error(req.W, "Error creating request", http.StatusInternalServerError)
		return
	}
	// NewRequest cannot infer the length of an arbitrary body; pass on what the
	// client declared so the upstream sees the same framing.
	targetReq.ContentLength = req.R.ContentLength

	// Copy headers from the original request.
	for name, values := range req.R.Header {
		for _, value := range values {
			targetReq.Header.Add(name, value)
		}
	}

	// http.DefaultClient is a zero-value client like the one this function
	// used to allocate per call. It deliberately has no Timeout: a client-wide
	// timeout would cut off long-lived event streams.
	resp, err := http.DefaultClient.Do(targetReq)
	if err != nil {
		log.Printf("Error forwarding request: %v", err)
		http.Error(req.W, "Error forwarding request", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	// Detect SSE by media type so that "text/event-stream; charset=utf-8" is
	// recognized too. A missing or malformed Content-Type yields an empty
	// media type and is simply treated as non-SSE, hence the ignored error.
	mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
	isSSE := mediaType == "text/event-stream"

	// Streaming needs a Flusher. Check before any header is touched so the
	// error response is still clean and the status code can still be set.
	var flusher http.Flusher
	if isSSE {
		f, ok := req.W.(http.Flusher)
		if !ok {
			log.Printf("ResponseWriter does not support flushing")
			http.Error(req.W, "Streaming unsupported", http.StatusInternalServerError)
			return
		}
		flusher = f
	}

	// Relay the upstream headers. Everything has to be in place before
	// WriteHeader: header changes made afterwards are not sent.
	header := req.W.Header()
	for name, values := range resp.Header {
		for _, value := range values {
			header.Add(name, value)
		}
	}
	if isSSE {
		// Content-Type is already copied from upstream; connection-level
		// headers (Connection, Transfer-Encoding) are left to net/http.
		header.Set("Cache-Control", "no-cache")
	}
	req.W.WriteHeader(resp.StatusCode)

	if !isSSE {
		// For non-SSE responses, just copy the body.
		if _, err := io.Copy(req.W, resp.Body); err != nil {
			log.Printf("Error copying response body: %v", err)
		}
		return
	}

	// SSE: push every chunk to the client as soon as it arrives.
	buf := make([]byte, 1024)
	for {
		n, err := resp.Body.Read(buf)
		if n > 0 {
			if _, writeErr := req.W.Write(buf[:n]); writeErr != nil {
				log.Printf("Error writing to client: %v", writeErr)
				return
			}
			flusher.Flush()
		}
		if err != nil {
			// io.EOF is the normal end of the stream.
			if err != io.EOF {
				log.Printf("Error reading from response body: %v", err)
			}
			return
		}
	}
}