| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| package httpclient |
|
|
| import ( |
| "fmt" |
| "net" |
| "net/http" |
| "strings" |
| "sync" |
| "time" |
|
|
| "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" |
| "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" |
| ) |
|
|
| |
// Transport defaults and cache lifetimes used by this package.
const (
	// defaultMaxIdleConns replaces Options.MaxIdleConns when it is <= 0.
	defaultMaxIdleConns = 100
	// defaultMaxIdleConnsPerHost replaces Options.MaxIdleConnsPerHost when it is <= 0.
	defaultMaxIdleConnsPerHost = 10
	// defaultIdleConnTimeout is the IdleConnTimeout set on every transport.
	defaultIdleConnTimeout = 90 * time.Second
	// defaultDialTimeout is the net.Dialer timeout used for new connections.
	defaultDialTimeout = 5 * time.Second
	// defaultTLSHandshakeTimeout is the TLSHandshakeTimeout set on every transport.
	defaultTLSHandshakeTimeout = 5 * time.Second
	// validatedHostTTL is how long validatedTransport remembers that a host
	// passed validation before checking it again.
	validatedHostTTL = 30 * time.Second
)
|
|
| |
// Options describes the HTTP client a caller wants from GetClient. Every
// field takes part in the cache key, so two Options values that differ in any
// field yield separate clients.
type Options struct {
	// ProxyURL is handed to proxyurl.Parse; when parsing yields no proxy the
	// transport is left without one.
	ProxyURL string
	// Timeout becomes http.Client.Timeout.
	Timeout time.Duration
	// ResponseHeaderTimeout becomes http.Transport.ResponseHeaderTimeout.
	ResponseHeaderTimeout time.Duration
	// InsecureSkipVerify is rejected: buildTransport returns an error when it
	// is set.
	InsecureSkipVerify bool
	// ValidateResolvedIP wraps the transport in a host-validating round
	// tripper, unless AllowPrivateHosts is also set.
	ValidateResolvedIP bool
	// AllowPrivateHosts disables the ValidateResolvedIP wrapper.
	AllowPrivateHosts bool

	// Connection pool settings.

	// MaxIdleConns falls back to defaultMaxIdleConns when <= 0.
	MaxIdleConns int
	// MaxIdleConnsPerHost falls back to defaultMaxIdleConnsPerHost when <= 0.
	MaxIdleConnsPerHost int
	// MaxConnsPerHost is passed to the transport unchanged.
	MaxConnsPerHost int
}
|
|
| |
// sharedClients caches the clients handed out by GetClient. Keys are the
// strings produced by buildClientKey and values are *http.Client.
var sharedClients sync.Map
|
|
| |
| var validateResolvedIP = urlvalidator.ValidateResolvedIP |
|
|
| |
| |
| |
| func GetClient(opts Options) (*http.Client, error) { |
| key := buildClientKey(opts) |
| if cached, ok := sharedClients.Load(key); ok { |
| if client, ok := cached.(*http.Client); ok { |
| return client, nil |
| } |
| } |
|
|
| client, err := buildClient(opts) |
| if err != nil { |
| return nil, err |
| } |
|
|
| actual, _ := sharedClients.LoadOrStore(key, client) |
| if c, ok := actual.(*http.Client); ok { |
| return c, nil |
| } |
| return client, nil |
| } |
|
|
| func buildClient(opts Options) (*http.Client, error) { |
| transport, err := buildTransport(opts) |
| if err != nil { |
| return nil, err |
| } |
|
|
| var rt http.RoundTripper = transport |
| if opts.ValidateResolvedIP && !opts.AllowPrivateHosts { |
| rt = newValidatedTransport(transport) |
| } |
| return &http.Client{ |
| Transport: rt, |
| Timeout: opts.Timeout, |
| }, nil |
| } |
|
|
| func buildTransport(opts Options) (*http.Transport, error) { |
| |
| maxIdleConns := opts.MaxIdleConns |
| if maxIdleConns <= 0 { |
| maxIdleConns = defaultMaxIdleConns |
| } |
| maxIdleConnsPerHost := opts.MaxIdleConnsPerHost |
| if maxIdleConnsPerHost <= 0 { |
| maxIdleConnsPerHost = defaultMaxIdleConnsPerHost |
| } |
|
|
| transport := &http.Transport{ |
| DialContext: (&net.Dialer{ |
| Timeout: defaultDialTimeout, |
| }).DialContext, |
| TLSHandshakeTimeout: defaultTLSHandshakeTimeout, |
| MaxIdleConns: maxIdleConns, |
| MaxIdleConnsPerHost: maxIdleConnsPerHost, |
| MaxConnsPerHost: opts.MaxConnsPerHost, |
| IdleConnTimeout: defaultIdleConnTimeout, |
| ResponseHeaderTimeout: opts.ResponseHeaderTimeout, |
| } |
|
|
| if opts.InsecureSkipVerify { |
| |
| return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead") |
| } |
|
|
| _, parsed, err := proxyurl.Parse(opts.ProxyURL) |
| if err != nil { |
| return nil, err |
| } |
| if parsed == nil { |
| return transport, nil |
| } |
|
|
| if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { |
| return nil, err |
| } |
|
|
| return transport, nil |
| } |
|
|
| func buildClientKey(opts Options) string { |
| return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%d|%d|%d", |
| strings.TrimSpace(opts.ProxyURL), |
| opts.Timeout.String(), |
| opts.ResponseHeaderTimeout.String(), |
| opts.InsecureSkipVerify, |
| opts.ValidateResolvedIP, |
| opts.AllowPrivateHosts, |
| opts.MaxIdleConns, |
| opts.MaxIdleConnsPerHost, |
| opts.MaxConnsPerHost, |
| ) |
| } |
|
|
// validatedTransport is an http.RoundTripper wrapper that runs
// validateResolvedIP on a request's host before delegating to base, and
// remembers hosts that passed for validatedHostTTL.
type validatedTransport struct {
	// base performs the actual round trip once the host has been accepted.
	base http.RoundTripper
	// validatedHosts maps a lower-cased host name to the time.Time at which
	// its successful validation expires.
	validatedHosts sync.Map
	// now supplies the current time; newValidatedTransport sets it to time.Now.
	now func() time.Time
}
|
|
| func newValidatedTransport(base http.RoundTripper) *validatedTransport { |
| return &validatedTransport{ |
| base: base, |
| now: time.Now, |
| } |
| } |
|
|
| func (t *validatedTransport) isValidatedHost(host string, now time.Time) bool { |
| if t == nil { |
| return false |
| } |
| raw, ok := t.validatedHosts.Load(host) |
| if !ok { |
| return false |
| } |
| expireAt, ok := raw.(time.Time) |
| if !ok { |
| t.validatedHosts.Delete(host) |
| return false |
| } |
| if now.Before(expireAt) { |
| return true |
| } |
| t.validatedHosts.Delete(host) |
| return false |
| } |
|
|
| func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) { |
| if req != nil && req.URL != nil { |
| host := strings.ToLower(strings.TrimSpace(req.URL.Hostname())) |
| if host != "" { |
| now := time.Now() |
| if t != nil && t.now != nil { |
| now = t.now() |
| } |
| if !t.isValidatedHost(host, now) { |
| if err := validateResolvedIP(host); err != nil { |
| return nil, err |
| } |
| t.validatedHosts.Store(host, now.Add(validatedHostTTL)) |
| } |
| } |
| } |
| if t == nil || t.base == nil { |
| return nil, fmt.Errorf("validated transport base is nil") |
| } |
| return t.base.RoundTrip(req) |
| } |
|
|