| package controller |
|
|
| import ( |
| "encoding/json" |
| "sort" |
| "strconv" |
| "strings" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/constant" |
| "github.com/QuantumNous/new-api/model" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
| |
| func GetAllModelsMeta(c *gin.Context) { |
|
|
| pageInfo := common.GetPageQuery(c) |
| modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| |
| enrichModels(modelsMeta) |
| var total int64 |
| model.DB.Model(&model.Model{}).Count(&total) |
|
|
| |
| vendorCounts, _ := model.GetVendorModelCounts() |
|
|
| pageInfo.SetTotal(int(total)) |
| pageInfo.SetItems(modelsMeta) |
| common.ApiSuccess(c, gin.H{ |
| "items": modelsMeta, |
| "total": total, |
| "page": pageInfo.GetPage(), |
| "page_size": pageInfo.GetPageSize(), |
| "vendor_counts": vendorCounts, |
| }) |
| } |
|
|
| |
| func SearchModelsMeta(c *gin.Context) { |
|
|
| keyword := c.Query("keyword") |
| vendor := c.Query("vendor") |
| pageInfo := common.GetPageQuery(c) |
|
|
| modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| |
| enrichModels(modelsMeta) |
| pageInfo.SetTotal(int(total)) |
| pageInfo.SetItems(modelsMeta) |
| common.ApiSuccess(c, pageInfo) |
| } |
|
|
| |
| func GetModelMeta(c *gin.Context) { |
| idStr := c.Param("id") |
| id, err := strconv.Atoi(idStr) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| var m model.Model |
| if err := model.DB.First(&m, id).Error; err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| enrichModels([]*model.Model{&m}) |
| common.ApiSuccess(c, &m) |
| } |
|
|
| |
| func CreateModelMeta(c *gin.Context) { |
| var m model.Model |
| if err := c.ShouldBindJSON(&m); err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| if m.ModelName == "" { |
| common.ApiErrorMsg(c, "模型名称不能为空") |
| return |
| } |
| |
| if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil { |
| common.ApiError(c, err) |
| return |
| } else if dup { |
| common.ApiErrorMsg(c, "模型名称已存在") |
| return |
| } |
|
|
| if err := m.Insert(); err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| model.RefreshPricing() |
| common.ApiSuccess(c, &m) |
| } |
|
|
| |
| func UpdateModelMeta(c *gin.Context) { |
| statusOnly := c.Query("status_only") == "true" |
|
|
| var m model.Model |
| if err := c.ShouldBindJSON(&m); err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| if m.Id == 0 { |
| common.ApiErrorMsg(c, "缺少模型 ID") |
| return |
| } |
|
|
| if statusOnly { |
| |
| if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| } else { |
| |
| if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil { |
| common.ApiError(c, err) |
| return |
| } else if dup { |
| common.ApiErrorMsg(c, "模型名称已存在") |
| return |
| } |
|
|
| if err := m.Update(); err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| } |
| model.RefreshPricing() |
| common.ApiSuccess(c, &m) |
| } |
|
|
| |
| func DeleteModelMeta(c *gin.Context) { |
| idStr := c.Param("id") |
| id, err := strconv.Atoi(idStr) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| if err := model.DB.Delete(&model.Model{}, id).Error; err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| model.RefreshPricing() |
| common.ApiSuccess(c, nil) |
| } |
|
|
| |
| func enrichModels(models []*model.Model) { |
| if len(models) == 0 { |
| return |
| } |
|
|
| |
| exactNames := make([]string, 0) |
| exactIdx := make(map[string][]int) |
| ruleIndices := make([]int, 0) |
| for i, m := range models { |
| if m == nil { |
| continue |
| } |
| if m.NameRule == model.NameRuleExact { |
| exactNames = append(exactNames, m.ModelName) |
| exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i) |
| } else { |
| ruleIndices = append(ruleIndices, i) |
| } |
| } |
|
|
| |
| channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames) |
|
|
| |
| for name, indices := range exactIdx { |
| chs := channelsByModel[name] |
| for _, idx := range indices { |
| mm := models[idx] |
| if mm.Endpoints == "" { |
| eps := model.GetModelSupportEndpointTypes(mm.ModelName) |
| if b, err := json.Marshal(eps); err == nil { |
| mm.Endpoints = string(b) |
| } |
| } |
| mm.BoundChannels = chs |
| mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName) |
| mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName) |
| } |
| } |
|
|
| if len(ruleIndices) == 0 { |
| return |
| } |
|
|
| |
| pricings := model.GetPricing() |
|
|
| |
| matchedNamesByIdx := make(map[int][]string) |
| endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{}) |
| groupSetByIdx := make(map[int]map[string]struct{}) |
| quotaSetByIdx := make(map[int]map[int]struct{}) |
|
|
| for _, p := range pricings { |
| for _, idx := range ruleIndices { |
| mm := models[idx] |
| var matched bool |
| switch mm.NameRule { |
| case model.NameRulePrefix: |
| matched = strings.HasPrefix(p.ModelName, mm.ModelName) |
| case model.NameRuleSuffix: |
| matched = strings.HasSuffix(p.ModelName, mm.ModelName) |
| case model.NameRuleContains: |
| matched = strings.Contains(p.ModelName, mm.ModelName) |
| } |
| if !matched { |
| continue |
| } |
| matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName) |
|
|
| es := endpointSetByIdx[idx] |
| if es == nil { |
| es = make(map[constant.EndpointType]struct{}) |
| endpointSetByIdx[idx] = es |
| } |
| for _, et := range p.SupportedEndpointTypes { |
| es[et] = struct{}{} |
| } |
|
|
| gs := groupSetByIdx[idx] |
| if gs == nil { |
| gs = make(map[string]struct{}) |
| groupSetByIdx[idx] = gs |
| } |
| for _, g := range p.EnableGroup { |
| gs[g] = struct{}{} |
| } |
|
|
| qs := quotaSetByIdx[idx] |
| if qs == nil { |
| qs = make(map[int]struct{}) |
| quotaSetByIdx[idx] = qs |
| } |
| qs[p.QuotaType] = struct{}{} |
| } |
| } |
|
|
| |
| allMatchedSet := make(map[string]struct{}) |
| for _, names := range matchedNamesByIdx { |
| for _, n := range names { |
| allMatchedSet[n] = struct{}{} |
| } |
| } |
| allMatched := make([]string, 0, len(allMatchedSet)) |
| for n := range allMatchedSet { |
| allMatched = append(allMatched, n) |
| } |
| matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched) |
|
|
| |
| for _, idx := range ruleIndices { |
| mm := models[idx] |
|
|
| |
| if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" { |
| eps := make([]constant.EndpointType, 0, len(es)) |
| for et := range es { |
| eps = append(eps, et) |
| } |
| if b, err := json.Marshal(eps); err == nil { |
| mm.Endpoints = string(b) |
| } |
| } |
|
|
| |
| if gs, ok := groupSetByIdx[idx]; ok { |
| groups := make([]string, 0, len(gs)) |
| for g := range gs { |
| groups = append(groups, g) |
| } |
| mm.EnableGroups = groups |
| } |
|
|
| |
| if qs, ok := quotaSetByIdx[idx]; ok { |
| arr := make([]int, 0, len(qs)) |
| for k := range qs { |
| arr = append(arr, k) |
| } |
| sort.Ints(arr) |
| mm.QuotaTypes = arr |
| } |
|
|
| |
| names := matchedNamesByIdx[idx] |
| channelSet := make(map[string]model.BoundChannel) |
| for _, n := range names { |
| for _, ch := range matchedChannelsByModel[n] { |
| key := ch.Name + "_" + strconv.Itoa(ch.Type) |
| channelSet[key] = ch |
| } |
| } |
| if len(channelSet) > 0 { |
| chs := make([]model.BoundChannel, 0, len(channelSet)) |
| for _, ch := range channelSet { |
| chs = append(chs, ch) |
| } |
| mm.BoundChannels = chs |
| } |
|
|
| |
| mm.MatchedModels = names |
| mm.MatchedCount = len(names) |
| } |
| } |
|
|