| | package middleware |
| |
|
import (
	"crypto/sha256"
	"crypto/subtle"
	"errors"
	"net/http"
	"strings"

	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/schema"
)
| |
|
var (
	// ErrMissingOrMalformedAPIKey is the sentinel error used when a request
	// carries no API key, or carries one that matches none of the configured
	// keys. Compare against it with errors.Is.
	ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")
)
| |
|
| | |
| | func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) { |
| | |
| | validator := getApiKeyValidationFunction(applicationConfig) |
| |
|
| | |
| | errorHandler := getApiKeyErrorHandler(applicationConfig) |
| |
|
| | |
| | skipper := getApiKeyRequiredFilterFunction(applicationConfig) |
| |
|
| | |
| | return func(next echo.HandlerFunc) echo.HandlerFunc { |
| | return func(c echo.Context) error { |
| | if len(applicationConfig.ApiKeys) == 0 { |
| | return next(c) |
| | } |
| |
|
| | |
| | if skipper != nil && skipper(c) { |
| | return next(c) |
| | } |
| |
|
| | |
| | key, err := extractKeyFromMultipleSources(c) |
| | if err != nil { |
| | return errorHandler(err, c) |
| | } |
| |
|
| | |
| | valid, err := validator(key, c) |
| | if err != nil || !valid { |
| | return errorHandler(ErrMissingOrMalformedAPIKey, c) |
| | } |
| |
|
| | |
| | c.Set("api_key", key) |
| |
|
| | return next(c) |
| | } |
| | }, nil |
| | } |
| |
|
| | |
| | |
| | func extractKeyFromMultipleSources(c echo.Context) (string, error) { |
| | |
| | auth := c.Request().Header.Get("Authorization") |
| | if auth != "" { |
| | |
| | if strings.HasPrefix(auth, "Bearer ") { |
| | return strings.TrimPrefix(auth, "Bearer "), nil |
| | } |
| | |
| | return auth, nil |
| | } |
| |
|
| | |
| | if key := c.Request().Header.Get("x-api-key"); key != "" { |
| | return key, nil |
| | } |
| |
|
| | |
| | if key := c.Request().Header.Get("xi-api-key"); key != "" { |
| | return key, nil |
| | } |
| |
|
| | |
| | cookie, err := c.Cookie("token") |
| | if err == nil && cookie != nil && cookie.Value != "" { |
| | return cookie.Value, nil |
| | } |
| |
|
| | return "", ErrMissingOrMalformedAPIKey |
| | } |
| |
|
| | func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error { |
| | return func(err error, c echo.Context) error { |
| | if errors.Is(err, ErrMissingOrMalformedAPIKey) { |
| | if len(applicationConfig.ApiKeys) == 0 { |
| | return nil |
| | } |
| | c.Response().Header().Set("WWW-Authenticate", "Bearer") |
| | if applicationConfig.OpaqueErrors { |
| | return c.NoContent(http.StatusUnauthorized) |
| | } |
| |
|
| | |
| | contentType := c.Request().Header.Get("Content-Type") |
| | if strings.Contains(contentType, "application/json") { |
| | return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ |
| | Error: &schema.APIError{ |
| | Message: "An authentication key is required", |
| | Code: 401, |
| | Type: "invalid_request_error", |
| | }, |
| | }) |
| | } |
| |
|
| | return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{ |
| | "BaseURL": BaseURL(c), |
| | }) |
| | } |
| | if applicationConfig.OpaqueErrors { |
| | return c.NoContent(http.StatusInternalServerError) |
| | } |
| | return err |
| | } |
| | } |
| |
|
| | func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) { |
| | if applicationConfig.UseSubtleKeyComparison { |
| | return func(key string, c echo.Context) (bool, error) { |
| | if len(applicationConfig.ApiKeys) == 0 { |
| | return true, nil |
| | } |
| | for _, validKey := range applicationConfig.ApiKeys { |
| | if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 { |
| | return true, nil |
| | } |
| | } |
| | return false, ErrMissingOrMalformedAPIKey |
| | } |
| | } |
| |
|
| | return func(key string, c echo.Context) (bool, error) { |
| | if len(applicationConfig.ApiKeys) == 0 { |
| | return true, nil |
| | } |
| | for _, validKey := range applicationConfig.ApiKeys { |
| | if key == validKey { |
| | return true, nil |
| | } |
| | } |
| | return false, ErrMissingOrMalformedAPIKey |
| | } |
| | } |
| |
|
| | func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper { |
| | return func(c echo.Context) bool { |
| | path := c.Request().URL.Path |
| |
|
| | for _, p := range applicationConfig.PathWithoutAuth { |
| | if strings.HasPrefix(path, p) { |
| | return true |
| | } |
| | } |
| |
|
| | |
| | if applicationConfig.DisableApiKeyRequirementForHttpGet { |
| | if c.Request().Method != http.MethodGet { |
| | return false |
| | } |
| | for _, rx := range applicationConfig.HttpGetExemptedEndpoints { |
| | if rx.MatchString(c.Path()) { |
| | return true |
| | } |
| | } |
| | } |
| |
|
| | return false |
| | } |
| | } |
| |
|