| package types
|
|
|
| import (
|
| "encoding/json"
|
| "errors"
|
| "fmt"
|
| "net/http"
|
| "strings"
|
|
|
| "github.com/QuantumNous/new-api/common"
|
| )
|
|
|
| type OpenAIError struct {
|
| Message string `json:"message"`
|
| Type string `json:"type"`
|
| Param string `json:"param"`
|
| Code any `json:"code"`
|
| Metadata json.RawMessage `json:"metadata,omitempty"`
|
| }
|
|
|
| type ClaudeError struct {
|
| Type string `json:"type,omitempty"`
|
| Message string `json:"message,omitempty"`
|
| }
|
|
|
| type ErrorType string
|
|
|
| const (
|
| ErrorTypeNewAPIError ErrorType = "new_api_error"
|
| ErrorTypeOpenAIError ErrorType = "openai_error"
|
| ErrorTypeClaudeError ErrorType = "claude_error"
|
| ErrorTypeMidjourneyError ErrorType = "midjourney_error"
|
| ErrorTypeGeminiError ErrorType = "gemini_error"
|
| ErrorTypeRerankError ErrorType = "rerank_error"
|
| ErrorTypeUpstreamError ErrorType = "upstream_error"
|
| )
|
|
|
| type ErrorCode string
|
|
|
| const (
|
| ErrorCodeInvalidRequest ErrorCode = "invalid_request"
|
| ErrorCodeSensitiveWordsDetected ErrorCode = "sensitive_words_detected"
|
| ErrorCodeViolationFeeGrokCSAM ErrorCode = "violation_fee.grok.csam"
|
|
|
|
|
| ErrorCodeCountTokenFailed ErrorCode = "count_token_failed"
|
| ErrorCodeModelPriceError ErrorCode = "model_price_error"
|
| ErrorCodeInvalidApiType ErrorCode = "invalid_api_type"
|
| ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed"
|
| ErrorCodeDoRequestFailed ErrorCode = "do_request_failed"
|
| ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed"
|
| ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed"
|
|
|
|
|
| ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key"
|
| ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid"
|
| ErrorCodeChannelHeaderOverrideInvalid ErrorCode = "channel:header_override_invalid"
|
| ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error"
|
| ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error"
|
| ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key"
|
| ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded"
|
|
|
|
|
| ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed"
|
| ErrorCodeConvertRequestFailed ErrorCode = "convert_request_failed"
|
| ErrorCodeAccessDenied ErrorCode = "access_denied"
|
|
|
|
|
| ErrorCodeBadRequestBody ErrorCode = "bad_request_body"
|
|
|
|
|
| ErrorCodeReadResponseBodyFailed ErrorCode = "read_response_body_failed"
|
| ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code"
|
| ErrorCodeBadResponse ErrorCode = "bad_response"
|
| ErrorCodeBadResponseBody ErrorCode = "bad_response_body"
|
| ErrorCodeEmptyResponse ErrorCode = "empty_response"
|
| ErrorCodeAwsInvokeError ErrorCode = "aws_invoke_error"
|
| ErrorCodeModelNotFound ErrorCode = "model_not_found"
|
| ErrorCodePromptBlocked ErrorCode = "prompt_blocked"
|
|
|
|
|
| ErrorCodeQueryDataError ErrorCode = "query_data_error"
|
| ErrorCodeUpdateDataError ErrorCode = "update_data_error"
|
|
|
|
|
| ErrorCodeInsufficientUserQuota ErrorCode = "insufficient_user_quota"
|
| ErrorCodePreConsumeTokenQuotaFailed ErrorCode = "pre_consume_token_quota_failed"
|
| )
|
|
|
| type NewAPIError struct {
|
| Err error
|
| RelayError any
|
| skipRetry bool
|
| recordErrorLog *bool
|
| errorType ErrorType
|
| errorCode ErrorCode
|
| StatusCode int
|
| Metadata json.RawMessage
|
| }
|
|
|
|
|
| func (e *NewAPIError) Unwrap() error {
|
| if e == nil {
|
| return nil
|
| }
|
| return e.Err
|
| }
|
|
|
| func (e *NewAPIError) GetErrorCode() ErrorCode {
|
| if e == nil {
|
| return ""
|
| }
|
| return e.errorCode
|
| }
|
|
|
| func (e *NewAPIError) GetErrorType() ErrorType {
|
| if e == nil {
|
| return ""
|
| }
|
| return e.errorType
|
| }
|
|
|
| func (e *NewAPIError) Error() string {
|
| if e == nil {
|
| return ""
|
| }
|
| if e.Err == nil {
|
|
|
| return string(e.errorCode)
|
| }
|
| return e.Err.Error()
|
| }
|
|
|
| func (e *NewAPIError) ErrorWithStatusCode() string {
|
| if e == nil {
|
| return ""
|
| }
|
| msg := e.Error()
|
| if e.StatusCode == 0 {
|
| return msg
|
| }
|
| if msg == "" {
|
| return fmt.Sprintf("status_code=%d", e.StatusCode)
|
| }
|
| return fmt.Sprintf("status_code=%d, %s", e.StatusCode, msg)
|
| }
|
|
|
| func (e *NewAPIError) MaskSensitiveError() string {
|
| if e == nil {
|
| return ""
|
| }
|
| if e.Err == nil {
|
| return string(e.errorCode)
|
| }
|
| errStr := e.Err.Error()
|
| if e.errorCode == ErrorCodeCountTokenFailed {
|
| return errStr
|
| }
|
| return common.MaskSensitiveInfo(errStr)
|
| }
|
|
|
| func (e *NewAPIError) MaskSensitiveErrorWithStatusCode() string {
|
| if e == nil {
|
| return ""
|
| }
|
| msg := e.MaskSensitiveError()
|
| if e.StatusCode == 0 {
|
| return msg
|
| }
|
| if msg == "" {
|
| return fmt.Sprintf("status_code=%d", e.StatusCode)
|
| }
|
| return fmt.Sprintf("status_code=%d, %s", e.StatusCode, msg)
|
| }
|
|
|
| func (e *NewAPIError) SetMessage(message string) {
|
| e.Err = errors.New(message)
|
| }
|
|
|
| func (e *NewAPIError) ToOpenAIError() OpenAIError {
|
| var result OpenAIError
|
| switch e.errorType {
|
| case ErrorTypeOpenAIError:
|
| if openAIError, ok := e.RelayError.(OpenAIError); ok {
|
| result = openAIError
|
| }
|
| case ErrorTypeClaudeError:
|
| if claudeError, ok := e.RelayError.(ClaudeError); ok {
|
| result = OpenAIError{
|
| Message: e.Error(),
|
| Type: claudeError.Type,
|
| Param: "",
|
| Code: e.errorCode,
|
| }
|
| }
|
| default:
|
| result = OpenAIError{
|
| Message: e.Error(),
|
| Type: string(e.errorType),
|
| Param: "",
|
| Code: e.errorCode,
|
| }
|
| }
|
| if e.errorCode != ErrorCodeCountTokenFailed {
|
| result.Message = common.MaskSensitiveInfo(result.Message)
|
| }
|
| if result.Message == "" {
|
| result.Message = string(e.errorType)
|
| }
|
| return result
|
| }
|
|
|
| func (e *NewAPIError) ToClaudeError() ClaudeError {
|
| var result ClaudeError
|
| switch e.errorType {
|
| case ErrorTypeOpenAIError:
|
| if openAIError, ok := e.RelayError.(OpenAIError); ok {
|
| result = ClaudeError{
|
| Message: e.Error(),
|
| Type: fmt.Sprintf("%v", openAIError.Code),
|
| }
|
| }
|
| case ErrorTypeClaudeError:
|
| if claudeError, ok := e.RelayError.(ClaudeError); ok {
|
| result = claudeError
|
| }
|
| default:
|
| result = ClaudeError{
|
| Message: e.Error(),
|
| Type: string(e.errorType),
|
| }
|
| }
|
| if e.errorCode != ErrorCodeCountTokenFailed {
|
| result.Message = common.MaskSensitiveInfo(result.Message)
|
| }
|
| if result.Message == "" {
|
| result.Message = string(e.errorType)
|
| }
|
| return result
|
| }
|
|
|
| type NewAPIErrorOptions func(*NewAPIError)
|
|
|
| func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError {
|
| var newErr *NewAPIError
|
|
|
| if errors.As(err, &newErr) {
|
| for _, op := range ops {
|
| op(newErr)
|
| }
|
| return newErr
|
| }
|
| e := &NewAPIError{
|
| Err: err,
|
| RelayError: nil,
|
| errorType: ErrorTypeNewAPIError,
|
| StatusCode: http.StatusInternalServerError,
|
| errorCode: errorCode,
|
| }
|
| for _, op := range ops {
|
| op(e)
|
| }
|
| return e
|
| }
|
|
|
| func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
| var newErr *NewAPIError
|
|
|
| if errors.As(err, &newErr) {
|
| if newErr.RelayError == nil {
|
| openaiError := OpenAIError{
|
| Message: newErr.Error(),
|
| Type: string(errorCode),
|
| Code: errorCode,
|
| }
|
| newErr.RelayError = openaiError
|
| }
|
| for _, op := range ops {
|
| op(newErr)
|
| }
|
| return newErr
|
| }
|
| openaiError := OpenAIError{
|
| Message: err.Error(),
|
| Type: string(errorCode),
|
| Code: errorCode,
|
| }
|
| return WithOpenAIError(openaiError, statusCode, ops...)
|
| }
|
|
|
| func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
| openaiError := OpenAIError{
|
| Type: string(errorCode),
|
| Code: errorCode,
|
| }
|
| return WithOpenAIError(openaiError, statusCode, ops...)
|
| }
|
|
|
| func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
| e := &NewAPIError{
|
| Err: err,
|
| RelayError: OpenAIError{
|
| Message: err.Error(),
|
| Type: string(errorCode),
|
| },
|
| errorType: ErrorTypeNewAPIError,
|
| StatusCode: statusCode,
|
| errorCode: errorCode,
|
| }
|
| for _, op := range ops {
|
| op(e)
|
| }
|
|
|
| return e
|
| }
|
|
|
| func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
| code, ok := openAIError.Code.(string)
|
| if !ok {
|
| if openAIError.Code != nil {
|
| code = fmt.Sprintf("%v", openAIError.Code)
|
| } else {
|
| code = "unknown_error"
|
| }
|
| }
|
| if openAIError.Type == "" {
|
| openAIError.Type = "upstream_error"
|
| }
|
| e := &NewAPIError{
|
| RelayError: openAIError,
|
| errorType: ErrorTypeOpenAIError,
|
| StatusCode: statusCode,
|
| Err: errors.New(openAIError.Message),
|
| errorCode: ErrorCode(code),
|
| }
|
|
|
| if len(openAIError.Metadata) > 0 {
|
| openAIError.Message = fmt.Sprintf("%s (%s)", openAIError.Message, openAIError.Metadata)
|
| e.Metadata = openAIError.Metadata
|
| e.RelayError = openAIError
|
| e.Err = errors.New(openAIError.Message)
|
| }
|
| for _, op := range ops {
|
| op(e)
|
| }
|
| return e
|
| }
|
|
|
| func WithClaudeError(claudeError ClaudeError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
| if claudeError.Type == "" {
|
| claudeError.Type = "upstream_error"
|
| }
|
| e := &NewAPIError{
|
| RelayError: claudeError,
|
| errorType: ErrorTypeClaudeError,
|
| StatusCode: statusCode,
|
| Err: errors.New(claudeError.Message),
|
| errorCode: ErrorCode(claudeError.Type),
|
| }
|
| for _, op := range ops {
|
| op(e)
|
| }
|
| return e
|
| }
|
|
|
| func IsChannelError(err *NewAPIError) bool {
|
| if err == nil {
|
| return false
|
| }
|
| return strings.HasPrefix(string(err.errorCode), "channel:")
|
| }
|
|
|
| func IsSkipRetryError(err *NewAPIError) bool {
|
| if err == nil {
|
| return false
|
| }
|
|
|
| return err.skipRetry
|
| }
|
|
|
| func ErrOptionWithSkipRetry() NewAPIErrorOptions {
|
| return func(e *NewAPIError) {
|
| e.skipRetry = true
|
| }
|
| }
|
|
|
| func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
|
| return func(e *NewAPIError) {
|
| e.recordErrorLog = common.GetPointer(false)
|
| }
|
| }
|
|
|
| func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions {
|
| return func(e *NewAPIError) {
|
| if common.DebugEnabled {
|
| fmt.Printf("ErrOptionWithHideErrMsg: %s, origin error: %s", replaceStr, e.Err)
|
| }
|
| e.Err = errors.New(replaceStr)
|
| }
|
| }
|
|
|
| func IsRecordErrorLog(e *NewAPIError) bool {
|
| if e == nil {
|
| return false
|
| }
|
| if e.recordErrorLog == nil {
|
|
|
| return true
|
| }
|
| return *e.recordErrorLog
|
| }
|
|
|