| package controller |
|
|
| import ( |
| "context" |
| "encoding/json" |
| "fmt" |
| "io" |
| "net" |
| "net/http" |
| "strings" |
| "sync" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/logger" |
|
|
| "github.com/QuantumNous/new-api/dto" |
| "github.com/QuantumNous/new-api/model" |
| "github.com/QuantumNous/new-api/setting/ratio_setting" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
const (
	// defaultTimeoutSeconds is the per-upstream fetch timeout applied when the
	// request does not supply a positive timeout.
	defaultTimeoutSeconds = 10
	// maxConcurrentFetches bounds how many upstreams are queried at the same time.
	maxConcurrentFetches = 8
	// maxRatioConfigBytes caps how many bytes of an upstream response body are
	// decoded (10 MiB).
	maxRatioConfigBytes = 10 * 1024 * 1024

	// defaultEndpoint is the path appended to an upstream's base URL when no
	// endpoint is configured.
	defaultEndpoint = "/api/ratio_config"

	// floatEpsilon is the absolute tolerance used when comparing ratio values.
	floatEpsilon = 1e-9
)
|
|
| func nearlyEqual(a, b float64) bool { |
| if a > b { |
| return a-b < floatEpsilon |
| } |
| return b-a < floatEpsilon |
| } |
|
|
| func valuesEqual(a, b interface{}) bool { |
| af, aok := a.(float64) |
| bf, bok := b.(float64) |
| if aok && bok { |
| return nearlyEqual(af, bf) |
| } |
| return a == b |
| } |
|
|
// ratioTypes lists, in a fixed order, the keys under which ratio data is
// grouped both in the local exposed settings and in upstream payloads.
var ratioTypes = []string{
	"model_ratio",
	"completion_ratio",
	"cache_ratio",
	"model_price",
}
|
|
// upstreamResult is the outcome of fetching one upstream's ratio data.
// The collecting code treats a non-empty Err as a failure; when Err is empty
// the upstream is counted as successful and Data is used.
type upstreamResult struct {
	// Name identifies the upstream in results; for stored channels (non-zero
	// ID) it is formatted as "name(id)".
	Name string `json:"name"`
	// Data holds the upstream's ratio data keyed by ratio type (e.g.
	// "model_ratio"); each value is expected to be a map[string]any of
	// model name to value.
	Data map[string]any `json:"data,omitempty"`
	// Err describes why the fetch failed.
	Err string `json:"err,omitempty"`
}
|
|
| func FetchUpstreamRatios(c *gin.Context) { |
| var req dto.UpstreamRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) |
| return |
| } |
|
|
| if req.Timeout <= 0 { |
| req.Timeout = defaultTimeoutSeconds |
| } |
|
|
| var upstreams []dto.UpstreamDTO |
|
|
| if len(req.Upstreams) > 0 { |
| for _, u := range req.Upstreams { |
| if strings.HasPrefix(u.BaseURL, "http") { |
| if u.Endpoint == "" { |
| u.Endpoint = defaultEndpoint |
| } |
| u.BaseURL = strings.TrimRight(u.BaseURL, "/") |
| upstreams = append(upstreams, u) |
| } |
| } |
| } else if len(req.ChannelIDs) > 0 { |
| intIds := make([]int, 0, len(req.ChannelIDs)) |
| for _, id64 := range req.ChannelIDs { |
| intIds = append(intIds, int(id64)) |
| } |
| dbChannels, err := model.GetChannelsByIds(intIds) |
| if err != nil { |
| logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) |
| c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) |
| return |
| } |
| for _, ch := range dbChannels { |
| if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { |
| upstreams = append(upstreams, dto.UpstreamDTO{ |
| ID: ch.Id, |
| Name: ch.Name, |
| BaseURL: strings.TrimRight(base, "/"), |
| Endpoint: "", |
| }) |
| } |
| } |
| } |
|
|
| if len(upstreams) == 0 { |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) |
| return |
| } |
|
|
| var wg sync.WaitGroup |
| ch := make(chan upstreamResult, len(upstreams)) |
|
|
| sem := make(chan struct{}, maxConcurrentFetches) |
|
|
| dialer := &net.Dialer{Timeout: 10 * time.Second} |
| transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second} |
| transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { |
| host, _, err := net.SplitHostPort(addr) |
| if err != nil { |
| host = addr |
| } |
| |
| if strings.HasSuffix(host, "github.io") { |
| if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil { |
| return conn, nil |
| } |
| return dialer.DialContext(ctx, "tcp6", addr) |
| } |
| return dialer.DialContext(ctx, network, addr) |
| } |
| client := &http.Client{Transport: transport} |
|
|
| for _, chn := range upstreams { |
| wg.Add(1) |
| go func(chItem dto.UpstreamDTO) { |
| defer wg.Done() |
|
|
| sem <- struct{}{} |
| defer func() { <-sem }() |
|
|
| endpoint := chItem.Endpoint |
| var fullURL string |
| if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") { |
| fullURL = endpoint |
| } else { |
| if endpoint == "" { |
| endpoint = defaultEndpoint |
| } else if !strings.HasPrefix(endpoint, "/") { |
| endpoint = "/" + endpoint |
| } |
| fullURL = chItem.BaseURL + endpoint |
| } |
|
|
| uniqueName := chItem.Name |
| if chItem.ID != 0 { |
| uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID) |
| } |
|
|
| ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) |
| defer cancel() |
|
|
| httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) |
| if err != nil { |
| logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) |
| ch <- upstreamResult{Name: uniqueName, Err: err.Error()} |
| return |
| } |
|
|
| |
| var resp *http.Response |
| var lastErr error |
| for attempt := 0; attempt < 3; attempt++ { |
| resp, lastErr = client.Do(httpReq) |
| if lastErr == nil { |
| break |
| } |
| time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond) |
| } |
| if lastErr != nil { |
| logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+lastErr.Error()) |
| ch <- upstreamResult{Name: uniqueName, Err: lastErr.Error()} |
| return |
| } |
| defer resp.Body.Close() |
| if resp.StatusCode != http.StatusOK { |
| logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) |
| ch <- upstreamResult{Name: uniqueName, Err: resp.Status} |
| return |
| } |
|
|
| |
| if ct := resp.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "application/json") { |
| logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct) |
| } |
| limited := io.LimitReader(resp.Body, maxRatioConfigBytes) |
| |
| |
| |
| var body struct { |
| Success bool `json:"success"` |
| Data json.RawMessage `json:"data"` |
| Message string `json:"message"` |
| } |
|
|
| if err := json.NewDecoder(limited).Decode(&body); err != nil { |
| logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) |
| ch <- upstreamResult{Name: uniqueName, Err: err.Error()} |
| return |
| } |
|
|
| if !body.Success { |
| ch <- upstreamResult{Name: uniqueName, Err: body.Message} |
| return |
| } |
|
|
| |
|
|
| |
| var type1Data map[string]any |
| if err := json.Unmarshal(body.Data, &type1Data); err == nil { |
| |
| isType1 := false |
| for _, rt := range ratioTypes { |
| if _, ok := type1Data[rt]; ok { |
| isType1 = true |
| break |
| } |
| } |
| if isType1 { |
| ch <- upstreamResult{Name: uniqueName, Data: type1Data} |
| return |
| } |
| } |
|
|
| |
| var pricingItems []struct { |
| ModelName string `json:"model_name"` |
| QuotaType int `json:"quota_type"` |
| ModelRatio float64 `json:"model_ratio"` |
| ModelPrice float64 `json:"model_price"` |
| CompletionRatio float64 `json:"completion_ratio"` |
| } |
| if err := json.Unmarshal(body.Data, &pricingItems); err != nil { |
| logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error()) |
| ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"} |
| return |
| } |
|
|
| modelRatioMap := make(map[string]float64) |
| completionRatioMap := make(map[string]float64) |
| modelPriceMap := make(map[string]float64) |
|
|
| for _, item := range pricingItems { |
| if item.QuotaType == 1 { |
| modelPriceMap[item.ModelName] = item.ModelPrice |
| } else { |
| modelRatioMap[item.ModelName] = item.ModelRatio |
| |
| completionRatioMap[item.ModelName] = item.CompletionRatio |
| } |
| } |
|
|
| converted := make(map[string]any) |
|
|
| if len(modelRatioMap) > 0 { |
| ratioAny := make(map[string]any, len(modelRatioMap)) |
| for k, v := range modelRatioMap { |
| ratioAny[k] = v |
| } |
| converted["model_ratio"] = ratioAny |
| } |
|
|
| if len(completionRatioMap) > 0 { |
| compAny := make(map[string]any, len(completionRatioMap)) |
| for k, v := range completionRatioMap { |
| compAny[k] = v |
| } |
| converted["completion_ratio"] = compAny |
| } |
|
|
| if len(modelPriceMap) > 0 { |
| priceAny := make(map[string]any, len(modelPriceMap)) |
| for k, v := range modelPriceMap { |
| priceAny[k] = v |
| } |
| converted["model_price"] = priceAny |
| } |
|
|
| ch <- upstreamResult{Name: uniqueName, Data: converted} |
| }(chn) |
| } |
|
|
| wg.Wait() |
| close(ch) |
|
|
| localData := ratio_setting.GetExposedData() |
|
|
| var testResults []dto.TestResult |
| var successfulChannels []struct { |
| name string |
| data map[string]any |
| } |
|
|
| for r := range ch { |
| if r.Err != "" { |
| testResults = append(testResults, dto.TestResult{ |
| Name: r.Name, |
| Status: "error", |
| Error: r.Err, |
| }) |
| } else { |
| testResults = append(testResults, dto.TestResult{ |
| Name: r.Name, |
| Status: "success", |
| }) |
| successfulChannels = append(successfulChannels, struct { |
| name string |
| data map[string]any |
| }{name: r.Name, data: r.Data}) |
| } |
| } |
|
|
| differences := buildDifferences(localData, successfulChannels) |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "data": gin.H{ |
| "differences": differences, |
| "test_results": testResults, |
| }, |
| }) |
| } |
|
|
// buildDifferences compares the local ratio settings with the data returned by
// each successful upstream and reports, per model and per ratio type, where
// they disagree.
//
// localData is read per ratio type; only entries whose value is a
// map[string]float64 are used. Each successfulChannels entry carries the
// upstream's display name and its data keyed by ratio type, where each value
// is expected to be a map[string]any (other shapes are ignored).
//
// The result maps modelName -> ratioType -> DifferenceItem. In an item's
// Upstreams map, a channel's value is:
//   - the upstream's raw value when it differs from the local value, or when
//     there is no local value;
//   - the string "same" when it equals the local value, or when neither the
//     channel nor the local settings contain the model;
//   - nil when the local settings have a value but the channel does not.
//
// Afterwards, channels that show no real difference anywhere are removed from
// every item. Items whose remaining upstreams are all "same" (or empty) are
// then dropped, and so are models left without items.
func buildDifferences(localData map[string]any, successfulChannels []struct {
	name string
	data map[string]any
}) map[string]map[string]dto.DifferenceItem {
	differences := make(map[string]map[string]dto.DifferenceItem)

	// Union of every model name seen locally or in any upstream, across all
	// ratio types.
	allModels := make(map[string]struct{})

	for _, ratioType := range ratioTypes {
		if localRatioAny, ok := localData[ratioType]; ok {
			if localRatio, ok := localRatioAny.(map[string]float64); ok {
				for modelName := range localRatio {
					allModels[modelName] = struct{}{}
				}
			}
		}
	}

	for _, channel := range successfulChannels {
		for _, ratioType := range ratioTypes {
			if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
				for modelName := range upstreamRatio {
					allModels[modelName] = struct{}{}
				}
			}
		}
	}

	// confidenceMap[channel][model] defaults to true. It becomes false only
	// when the channel provides both model_ratio and completion_ratio maps and
	// the model's values there are exactly 37.5 and 1.0.
	// NOTE(review): presumably that pair is a placeholder some upstreams
	// report for unpriced models — confirm.
	confidenceMap := make(map[string]map[string]bool)

	for _, channel := range successfulChannels {
		confidenceMap[channel.name] = make(map[string]bool)

		modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
		completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)

		if hasModelRatio && hasCompletionRatio {
			for modelName := range allModels {
				confidenceMap[channel.name][modelName] = true

				// Exact float comparison is intentional: the suspicious pair
				// is matched literally.
				if modelRatioVal, ok := modelRatios[modelName]; ok {
					if completionRatioVal, ok := completionRatios[modelName]; ok {
						if modelRatioFloat, ok := modelRatioVal.(float64); ok {
							if completionRatioFloat, ok := completionRatioVal.(float64); ok {
								if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
									confidenceMap[channel.name][modelName] = false
								}
							}
						}
					}
				}
			}
		} else {
			for modelName := range allModels {
				confidenceMap[channel.name][modelName] = true
			}
		}
	}

	for modelName := range allModels {
		for _, ratioType := range ratioTypes {
			// localValue stays nil when the local settings lack this model/type.
			var localValue interface{} = nil
			if localRatioAny, ok := localData[ratioType]; ok {
				if localRatio, ok := localRatioAny.(map[string]float64); ok {
					if val, exists := localRatio[modelName]; exists {
						localValue = val
					}
				}
			}

			upstreamValues := make(map[string]interface{})
			confidenceValues := make(map[string]bool)
			hasUpstreamValue := false
			hasDifference := false

			for _, channel := range successfulChannels {
				var upstreamValue interface{} = nil

				if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
					if val, exists := upstreamRatio[modelName]; exists {
						upstreamValue = val
						hasUpstreamValue = true

						// Equal values are replaced by the "same" sentinel.
						if localValue != nil && !valuesEqual(localValue, val) {
							hasDifference = true
						} else if valuesEqual(localValue, val) {
							upstreamValue = "same"
						}
					}
				}
				// Absent on both sides counts as agreement.
				if upstreamValue == nil && localValue == nil {
					upstreamValue = "same"
				}

				// A value the local settings do not have at all is a difference.
				if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
					hasDifference = true
				}

				upstreamValues[channel.name] = upstreamValue

				confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
			}

			// Report an item when a local value is contradicted, or when any
			// upstream supplies a value the local settings lack.
			shouldInclude := false

			if localValue != nil {
				if hasDifference {
					shouldInclude = true
				}
			} else {
				if hasUpstreamValue {
					shouldInclude = true
				}
			}

			if shouldInclude {
				if differences[modelName] == nil {
					differences[modelName] = make(map[string]dto.DifferenceItem)
				}
				differences[modelName][ratioType] = dto.DifferenceItem{
					Current:    localValue,
					Upstreams:  upstreamValues,
					Confidence: confidenceValues,
				}
			}
		}
	}

	// A channel "has a diff" if it reports a concrete (non-nil, non-"same")
	// value anywhere in the result.
	channelHasDiff := make(map[string]bool)
	for _, ratioMap := range differences {
		for _, item := range ratioMap {
			for chName, val := range item.Upstreams {
				if val != nil && val != "same" {
					channelHasDiff[chName] = true
				}
			}
		}
	}

	// Prune channels without any diff, then drop items and models that no
	// longer show a disagreement. Deleting during range is safe for Go maps;
	// Upstreams/Confidence are maps, so deletions through the copied item are
	// visible in differences.
	for modelName, ratioMap := range differences {
		for ratioType, item := range ratioMap {
			for chName := range item.Upstreams {
				if !channelHasDiff[chName] {
					delete(item.Upstreams, chName)
					delete(item.Confidence, chName)
				}
			}

			allSame := true
			for _, v := range item.Upstreams {
				if v != "same" {
					allSame = false
					break
				}
			}
			if len(item.Upstreams) == 0 || allSame {
				delete(ratioMap, ratioType)
			} else {
				differences[modelName][ratioType] = item
			}
		}

		if len(ratioMap) == 0 {
			delete(differences, modelName)
		}
	}

	return differences
}
|
|
| func GetSyncableChannels(c *gin.Context) { |
| channels, err := model.GetAllChannels(0, 0, true, false) |
| if err != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| return |
| } |
|
|
| var syncableChannels []dto.SyncableChannel |
| for _, channel := range channels { |
| if channel.GetBaseURL() != "" { |
| syncableChannels = append(syncableChannels, dto.SyncableChannel{ |
| ID: channel.Id, |
| Name: channel.Name, |
| BaseURL: channel.GetBaseURL(), |
| Status: channel.Status, |
| }) |
| } |
| } |
|
|
| syncableChannels = append(syncableChannels, dto.SyncableChannel{ |
| ID: -100, |
| Name: "官方倍率预设", |
| BaseURL: "https://basellm.github.io", |
| Status: 1, |
| }) |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": syncableChannels, |
| }) |
| } |
|
|