| package controller |
|
|
| import ( |
| "bytes" |
| "encoding/json" |
| "errors" |
| "fmt" |
| "io" |
| "math" |
| "net/http" |
| "net/http/httptest" |
| "net/url" |
| "strconv" |
| "strings" |
| "sync" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/constant" |
| "github.com/QuantumNous/new-api/dto" |
| "github.com/QuantumNous/new-api/middleware" |
| "github.com/QuantumNous/new-api/model" |
| "github.com/QuantumNous/new-api/relay" |
| relaycommon "github.com/QuantumNous/new-api/relay/common" |
| relayconstant "github.com/QuantumNous/new-api/relay/constant" |
| "github.com/QuantumNous/new-api/relay/helper" |
| "github.com/QuantumNous/new-api/service" |
| "github.com/QuantumNous/new-api/setting/operation_setting" |
| "github.com/QuantumNous/new-api/types" |
|
|
| "github.com/bytedance/gopkg/util/gopool" |
| "github.com/samber/lo" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
// testResult is the outcome of a single testChannel call.
type testResult struct {
	// context is the gin test context the probe ran in. It is nil when
	// testChannel returned before channel setup (unsupported channel type or
	// user-cache lookup failure).
	context *gin.Context
	// localErr is set on every failure path of testChannel and is nil on success.
	localErr error
	// newAPIError is the structured error for failures that happen after the
	// context was prepared; nil on success and for the early failures above.
	newAPIError *types.NewAPIError
}
|
|
| func testChannel(channel *model.Channel, testModel string, endpointType string) testResult { |
| tik := time.Now() |
| var unsupportedTestChannelTypes = []int{ |
| constant.ChannelTypeMidjourney, |
| constant.ChannelTypeMidjourneyPlus, |
| constant.ChannelTypeSunoAPI, |
| constant.ChannelTypeKling, |
| constant.ChannelTypeJimeng, |
| constant.ChannelTypeDoubaoVideo, |
| constant.ChannelTypeVidu, |
| } |
| if lo.Contains(unsupportedTestChannelTypes, channel.Type) { |
| channelTypeName := constant.GetChannelTypeName(channel.Type) |
| return testResult{ |
| localErr: fmt.Errorf("%s channel test is not supported", channelTypeName), |
| } |
| } |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| testModel = strings.TrimSpace(testModel) |
| if testModel == "" { |
| if channel.TestModel != nil && *channel.TestModel != "" { |
| testModel = strings.TrimSpace(*channel.TestModel) |
| } else { |
| models := channel.GetModels() |
| if len(models) > 0 { |
| testModel = strings.TrimSpace(models[0]) |
| } |
| if testModel == "" { |
| testModel = "gpt-4o-mini" |
| } |
| } |
| } |
|
|
| requestPath := "/v1/chat/completions" |
|
|
| |
| if endpointType != "" { |
| if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok { |
| requestPath = endpointInfo.Path |
| } |
| } else { |
| |
| |
| if strings.Contains(strings.ToLower(testModel), "embedding") || |
| strings.HasPrefix(testModel, "m3e") || |
| strings.Contains(testModel, "bge-") || |
| strings.Contains(testModel, "embed") || |
| channel.Type == constant.ChannelTypeMokaAI { |
| requestPath = "/v1/embeddings" |
| } |
|
|
| |
| if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") { |
| requestPath = "/v1/images/generations" |
| } |
| } |
|
|
| c.Request = &http.Request{ |
| Method: "POST", |
| URL: &url.URL{Path: requestPath}, |
| Body: nil, |
| Header: make(http.Header), |
| } |
|
|
| cache, err := model.GetUserCache(1) |
| if err != nil { |
| return testResult{ |
| localErr: err, |
| newAPIError: nil, |
| } |
| } |
| cache.WriteContext(c) |
|
|
| |
| c.Request.Header.Set("Content-Type", "application/json") |
| c.Set("channel", channel.Type) |
| c.Set("base_url", channel.GetBaseURL()) |
| group, _ := model.GetUserGroup(1, false) |
| c.Set("group", group) |
|
|
| newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel) |
| if newAPIError != nil { |
| return testResult{ |
| context: c, |
| localErr: newAPIError, |
| newAPIError: newAPIError, |
| } |
| } |
|
|
| |
| var relayFormat types.RelayFormat |
| if endpointType != "" { |
| |
| switch constant.EndpointType(endpointType) { |
| case constant.EndpointTypeOpenAI: |
| relayFormat = types.RelayFormatOpenAI |
| case constant.EndpointTypeOpenAIResponse: |
| relayFormat = types.RelayFormatOpenAIResponses |
| case constant.EndpointTypeAnthropic: |
| relayFormat = types.RelayFormatClaude |
| case constant.EndpointTypeGemini: |
| relayFormat = types.RelayFormatGemini |
| case constant.EndpointTypeJinaRerank: |
| relayFormat = types.RelayFormatRerank |
| case constant.EndpointTypeImageGeneration: |
| relayFormat = types.RelayFormatOpenAIImage |
| case constant.EndpointTypeEmbeddings: |
| relayFormat = types.RelayFormatEmbedding |
| default: |
| relayFormat = types.RelayFormatOpenAI |
| } |
| } else { |
| |
| relayFormat = types.RelayFormatOpenAI |
| if c.Request.URL.Path == "/v1/embeddings" { |
| relayFormat = types.RelayFormatEmbedding |
| } |
| if c.Request.URL.Path == "/v1/images/generations" { |
| relayFormat = types.RelayFormatOpenAIImage |
| } |
| if c.Request.URL.Path == "/v1/messages" { |
| relayFormat = types.RelayFormatClaude |
| } |
| if strings.Contains(c.Request.URL.Path, "/v1beta/models") { |
| relayFormat = types.RelayFormatGemini |
| } |
| if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" { |
| relayFormat = types.RelayFormatRerank |
| } |
| if c.Request.URL.Path == "/v1/responses" { |
| relayFormat = types.RelayFormatOpenAIResponses |
| } |
| } |
|
|
| request := buildTestRequest(testModel, endpointType) |
|
|
| info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil) |
|
|
| if err != nil { |
| return testResult{ |
| context: c, |
| localErr: err, |
| newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed), |
| } |
| } |
|
|
| info.InitChannelMeta(c) |
|
|
| err = helper.ModelMappedHelper(c, info, request) |
| if err != nil { |
| return testResult{ |
| context: c, |
| localErr: err, |
| newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError), |
| } |
| } |
|
|
| testModel = info.UpstreamModelName |
| |
| request.SetModelName(testModel) |
|
|
| apiType, _ := common.ChannelType2APIType(channel.Type) |
| adaptor := relay.GetAdaptor(apiType) |
| if adaptor == nil { |
| return testResult{ |
| context: c, |
| localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), |
| newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType), |
| } |
| } |
|
|
| |
| |
| |
| common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString())) |
|
|
| priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta()) |
| if err != nil { |
| return testResult{ |
| context: c, |
| localErr: err, |
| newAPIError: types.NewError(err, types.ErrorCodeModelPriceError), |
| } |
| } |
|
|
| adaptor.Init(info) |
|
|
| var convertedRequest any |
| |
| switch info.RelayMode { |
| case relayconstant.RelayModeEmbeddings: |
| |
| if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok { |
| convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq) |
| } else { |
| return testResult{ |
| context: c, |
| localErr: errors.New("invalid embedding request type"), |
| newAPIError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed), |
| } |
| } |
| case relayconstant.RelayModeImagesGenerations: |
| |
| if imageReq, ok := request.(*dto.ImageRequest); ok { |
| convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq) |
| } else { |
| return testResult{ |
| context: c, |
| localErr: errors.New("invalid image request type"), |
| newAPIError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed), |
| } |
| } |
| case relayconstant.RelayModeRerank: |
| |
| if rerankReq, ok := request.(*dto.RerankRequest); ok { |
| convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq) |
| } else { |
| return testResult{ |
| context: c, |
| localErr: errors.New("invalid rerank request type"), |
| newAPIError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed), |
| } |
| } |
| case relayconstant.RelayModeResponses: |
| |
| if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok { |
| convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq) |
| } else { |
| return testResult{ |
| context: c, |
| localErr: errors.New("invalid response request type"), |
| newAPIError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed), |
| } |
| } |
| default: |
| |
| if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok { |
| convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq) |
| } else { |
| return testResult{ |
| context: c, |
| localErr: errors.New("invalid general request type"), |
| newAPIError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed), |
| } |
| } |
| } |
|
|
| if err != nil { |
| return testResult{ |
| context: c, |
| localErr: err, |
| newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed), |
| } |
| } |
| jsonData, err := json.Marshal(convertedRequest) |
| if err != nil { |
| return testResult{ |
| context: c, |
| localErr: err, |
| newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed), |
| } |
| } |
| requestBody := bytes.NewBuffer(jsonData) |
| c.Request.Body = io.NopCloser(requestBody) |
| resp, err := adaptor.DoRequest(c, info, requestBody) |
| if err != nil { |
| return testResult{ |
| context: c, |
| localErr: err, |
| newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError), |
| } |
| } |
| var httpResp *http.Response |
| if resp != nil { |
| httpResp = resp.(*http.Response) |
| if httpResp.StatusCode != http.StatusOK { |
| err := service.RelayErrorHandler(c.Request.Context(), httpResp, true) |
| return testResult{ |
| context: c, |
| localErr: err, |
| newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError), |
| } |
| } |
| } |
| usageA, respErr := adaptor.DoResponse(c, httpResp, info) |
| if respErr != nil { |
| return testResult{ |
| context: c, |
| localErr: respErr, |
| newAPIError: respErr, |
| } |
| } |
| if usageA == nil { |
| return testResult{ |
| context: c, |
| localErr: errors.New("usage is nil"), |
| newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError), |
| } |
| } |
| usage := usageA.(*dto.Usage) |
| result := w.Result() |
| respBody, err := io.ReadAll(result.Body) |
| if err != nil { |
| return testResult{ |
| context: c, |
| localErr: err, |
| newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), |
| } |
| } |
| info.PromptTokens = usage.PromptTokens |
|
|
| quota := 0 |
| if !priceData.UsePrice { |
| quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio)) |
| quota = int(math.Round(float64(quota) * priceData.ModelRatio)) |
| if priceData.ModelRatio != 0 && quota <= 0 { |
| quota = 1 |
| } |
| } else { |
| quota = int(priceData.ModelPrice * common.QuotaPerUnit) |
| } |
| tok := time.Now() |
| milliseconds := tok.Sub(tik).Milliseconds() |
| consumedTime := float64(milliseconds) / 1000.0 |
| other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio, |
| usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) |
| model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{ |
| ChannelId: channel.Id, |
| PromptTokens: usage.PromptTokens, |
| CompletionTokens: usage.CompletionTokens, |
| ModelName: info.OriginModelName, |
| TokenName: "模型测试", |
| Quota: quota, |
| Content: "模型测试", |
| UseTimeSeconds: int(consumedTime), |
| IsStream: info.IsStream, |
| Group: info.UsingGroup, |
| Other: other, |
| }) |
| common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) |
| return testResult{ |
| context: c, |
| localErr: nil, |
| newAPIError: nil, |
| } |
| } |
|
|
| func buildTestRequest(model string, endpointType string) dto.Request { |
| |
| if endpointType != "" { |
| switch constant.EndpointType(endpointType) { |
| case constant.EndpointTypeEmbeddings: |
| |
| return &dto.EmbeddingRequest{ |
| Model: model, |
| Input: []any{"hello world"}, |
| } |
| case constant.EndpointTypeImageGeneration: |
| |
| return &dto.ImageRequest{ |
| Model: model, |
| Prompt: "a cute cat", |
| N: 1, |
| Size: "1024x1024", |
| } |
| case constant.EndpointTypeJinaRerank: |
| |
| return &dto.RerankRequest{ |
| Model: model, |
| Query: "What is Deep Learning?", |
| Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."}, |
| TopN: 2, |
| } |
| case constant.EndpointTypeOpenAIResponse: |
| |
| return &dto.OpenAIResponsesRequest{ |
| Model: model, |
| Input: json.RawMessage("\"hi\""), |
| } |
| case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI: |
| |
| maxTokens := uint(10) |
| if constant.EndpointType(endpointType) == constant.EndpointTypeGemini { |
| maxTokens = 3000 |
| } |
| return &dto.GeneralOpenAIRequest{ |
| Model: model, |
| Stream: false, |
| Messages: []dto.Message{ |
| { |
| Role: "user", |
| Content: "hi", |
| }, |
| }, |
| MaxTokens: maxTokens, |
| } |
| } |
| } |
|
|
| |
| |
| if strings.Contains(strings.ToLower(model), "embedding") || |
| strings.HasPrefix(model, "m3e") || |
| strings.Contains(model, "bge-") { |
| |
| return &dto.EmbeddingRequest{ |
| Model: model, |
| Input: []any{"hello world"}, |
| } |
| } |
|
|
| |
| testRequest := &dto.GeneralOpenAIRequest{ |
| Model: model, |
| Stream: false, |
| Messages: []dto.Message{ |
| { |
| Role: "user", |
| Content: "hi", |
| }, |
| }, |
| } |
|
|
| if strings.HasPrefix(model, "o") { |
| testRequest.MaxCompletionTokens = 10 |
| } else if strings.Contains(model, "thinking") { |
| if !strings.Contains(model, "claude") { |
| testRequest.MaxTokens = 50 |
| } |
| } else if strings.Contains(model, "gemini") { |
| testRequest.MaxTokens = 3000 |
| } else { |
| testRequest.MaxTokens = 10 |
| } |
|
|
| return testRequest |
| } |
|
|
| func TestChannel(c *gin.Context) { |
| channelId, err := strconv.Atoi(c.Param("id")) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| channel, err := model.CacheGetChannel(channelId) |
| if err != nil { |
| channel, err = model.GetChannelById(channelId, true) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| } |
| |
| |
| |
| |
| |
| testModel := c.Query("model") |
| endpointType := c.Query("endpoint_type") |
| tik := time.Now() |
| result := testChannel(channel, testModel, endpointType) |
| if result.localErr != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": result.localErr.Error(), |
| "time": 0.0, |
| }) |
| return |
| } |
| tok := time.Now() |
| milliseconds := tok.Sub(tik).Milliseconds() |
| go channel.UpdateResponseTime(milliseconds) |
| consumedTime := float64(milliseconds) / 1000.0 |
| if result.newAPIError != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": result.newAPIError.Error(), |
| "time": consumedTime, |
| }) |
| return |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "time": consumedTime, |
| }) |
| } |
|
|
// testAllChannelsLock guards testAllChannelsRunning. The flag is true while a
// testAllChannels run is in progress, so at most one full test runs at a time.
var (
	testAllChannelsLock    sync.Mutex
	testAllChannelsRunning bool
)
|
|
| func testAllChannels(notify bool) error { |
|
|
| testAllChannelsLock.Lock() |
| if testAllChannelsRunning { |
| testAllChannelsLock.Unlock() |
| return errors.New("测试已在运行中") |
| } |
| testAllChannelsRunning = true |
| testAllChannelsLock.Unlock() |
| channels, getChannelErr := model.GetAllChannels(0, 0, true, false) |
| if getChannelErr != nil { |
| return getChannelErr |
| } |
| var disableThreshold = int64(common.ChannelDisableThreshold * 1000) |
| if disableThreshold == 0 { |
| disableThreshold = 10000000 |
| } |
| gopool.Go(func() { |
| |
| defer func() { |
| testAllChannelsLock.Lock() |
| testAllChannelsRunning = false |
| testAllChannelsLock.Unlock() |
| }() |
|
|
| for _, channel := range channels { |
| isChannelEnabled := channel.Status == common.ChannelStatusEnabled |
| tik := time.Now() |
| result := testChannel(channel, "", "") |
| tok := time.Now() |
| milliseconds := tok.Sub(tik).Milliseconds() |
|
|
| shouldBanChannel := false |
| newAPIError := result.newAPIError |
| |
| if newAPIError != nil { |
| shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError) |
| } |
|
|
| |
| if common.AutomaticDisableChannelEnabled && !shouldBanChannel { |
| if milliseconds > disableThreshold { |
| err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) |
| newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout) |
| shouldBanChannel = true |
| } |
| } |
|
|
| |
| if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() { |
| processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) |
| } |
|
|
| |
| if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) { |
| service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name) |
| } |
|
|
| channel.UpdateResponseTime(milliseconds) |
| time.Sleep(common.RequestInterval) |
| } |
|
|
| if notify { |
| service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成") |
| } |
| }) |
| return nil |
| } |
|
|
| func TestAllChannels(c *gin.Context) { |
| err := testAllChannels(true) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| }) |
| } |
|
|
// autoTestChannelsOnce ensures AutomaticallyTestChannels starts its periodic
// test loop at most once per process.
var autoTestChannelsOnce sync.Once
|
|
| func AutomaticallyTestChannels() { |
| |
| if !common.IsMasterNode { |
| return |
| } |
| autoTestChannelsOnce.Do(func() { |
| for { |
| if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled { |
| time.Sleep(1 * time.Minute) |
| continue |
| } |
| for { |
| frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes |
| time.Sleep(time.Duration(int(math.Round(frequency))) * time.Minute) |
| common.SysLog(fmt.Sprintf("automatically test channels with interval %f minutes", frequency)) |
| common.SysLog("automatically testing all channels") |
| _ = testAllChannels(false) |
| common.SysLog("automatically channel test finished") |
| if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled { |
| break |
| } |
| } |
| } |
| }) |
| } |
|
|