diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..ffb8c21be128625dde5b08c77c890a603f406e70 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,33 @@ +FROM node:16 as builder + +WORKDIR /build +COPY web/package.json . +RUN npm install +COPY ./web . +COPY ./VERSION . +RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build + +FROM golang AS builder2 + +ENV GO111MODULE=on \ + CGO_ENABLED=1 \ + GOOS=linux + +WORKDIR /build +ADD go.mod go.sum ./ +RUN go mod download +COPY . . +COPY --from=builder /build/build ./web/build +RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api + +FROM alpine + +RUN apk update \ + && apk upgrade \ + && apk add --no-cache ca-certificates tzdata \ + && update-ca-certificates 2>/dev/null || true + +COPY --from=builder2 /build/one-api / +EXPOSE 3000 +WORKDIR /data +ENTRYPOINT ["/one-api"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4bf5d1b7cb1bafb2d35b797859ac8e7f36f57180 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 JustSong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/VERSION b/VERSION new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/bin/migration_v0.2-v0.3.sql b/bin/migration_v0.2-v0.3.sql new file mode 100644 index 0000000000000000000000000000000000000000..6b08d7bf9f4aa70e6ceeaa56137cd1bc8db075f5 --- /dev/null +++ b/bin/migration_v0.2-v0.3.sql @@ -0,0 +1,6 @@ +UPDATE users +SET quota = quota + ( + SELECT SUM(remain_quota) + FROM tokens + WHERE tokens.user_id = users.id +) diff --git a/bin/migration_v0.3-v0.4.sql b/bin/migration_v0.3-v0.4.sql new file mode 100644 index 0000000000000000000000000000000000000000..e6103c29acff677acf5d88f5df380d076e5e129f --- /dev/null +++ b/bin/migration_v0.3-v0.4.sql @@ -0,0 +1,17 @@ +INSERT INTO abilities (`group`, model, channel_id, enabled) +SELECT c.`group`, m.model, c.id, 1 +FROM channels c +CROSS JOIN ( + SELECT 'gpt-3.5-turbo' AS model UNION ALL + SELECT 'gpt-3.5-turbo-0301' AS model UNION ALL + SELECT 'gpt-4' AS model UNION ALL + SELECT 'gpt-4-0314' AS model +) AS m +WHERE c.status = 1 + AND NOT EXISTS ( + SELECT 1 + FROM abilities a + WHERE a.`group` = c.`group` + AND a.model = m.model + AND a.channel_id = c.id +); diff --git a/bin/time_test.sh b/bin/time_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..2cde4a65bbd4b3d7e60ca55b504d78b02b999f9d --- /dev/null +++ b/bin/time_test.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +if [ $# -lt 3 ]; then + echo "Usage: time_test.sh []" + exit 1 +fi + +domain=$1 +key=$2 +count=$3 +model=${4:-"gpt-3.5-turbo"} # 设置默认模型为 gpt-3.5-turbo + +total_time=0 +times=() + +for ((i=1; i<=count; i++)); do + result=$(curl -o /dev/null -s -w "%{http_code} %{time_total}\\n" \ + https://"$domain"/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $key" \ + -d '{"messages": [{"content": "echo hi", "role": "user"}], "model": "'"$model"'", "stream": false, "max_tokens": 1}') + http_code=$(echo "$result" | awk '{print $1}') + time=$(echo "$result" | awk '{print $2}') + echo "HTTP status code: $http_code, Time taken: $time" + total_time=$(bc <<< "$total_time + $time") + times+=("$time") +done + +average_time=$(echo "scale=4; $total_time / $count" | bc) + +sum_of_squares=0 +for time in "${times[@]}"; do + difference=$(echo "scale=4; $time - $average_time" | bc) + square=$(echo "scale=4; $difference * $difference" | bc) + sum_of_squares=$(echo "scale=4; $sum_of_squares + $square" | bc) +done + +standard_deviation=$(echo "scale=4; sqrt($sum_of_squares / $count)" | bc) + +echo "Average time: $average_time±$standard_deviation" diff --git a/common/constants.go b/common/constants.go new file mode 100644 index 0000000000000000000000000000000000000000..e5211e3d21452004515cffd53814dd06abc0f9bd --- /dev/null +++ b/common/constants.go @@ -0,0 +1,202 @@ +package common + +import ( + "os" + "strconv" + "sync" + "time" + + "github.com/google/uuid" +) + +var StartTime = time.Now().Unix() // unit: second +var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change +var SystemName = "One API" +var ServerAddress = "http://localhost:3000" +var Footer = "" +var Logo = "" +var TopUpLink = "" +var ChatLink = "" +var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens +var DisplayInCurrencyEnabled = true +var DisplayTokenStatEnabled = true + +var UsingSQLite = false + +// Any options with "Secret", "Token" in its key won't be return by GetOptions + +var SessionSecret = uuid.New().String() +var SQLitePath = "one-api.db" + +var OptionMap map[string]string +var OptionMapRWMutex sync.RWMutex + +var ItemsPerPage = 10 +var MaxRecentItems = 100 + +var PasswordLoginEnabled = true +var PasswordRegisterEnabled = true +var EmailVerificationEnabled = false +var GitHubOAuthEnabled = false +var WeChatAuthEnabled = false +var TurnstileCheckEnabled = false +var RegisterEnabled = true + +var EmailDomainRestrictionEnabled = false +var EmailDomainWhitelist = []string{ + "gmail.com", + "163.com", + "126.com", + "qq.com", + "outlook.com", + "hotmail.com", + "icloud.com", + "yahoo.com", + "foxmail.com", +} + +var DebugEnabled = os.Getenv("DEBUG") == "true" + +var LogConsumeEnabled = true + +var SMTPServer = "" +var SMTPPort = 587 +var SMTPAccount = "" +var SMTPFrom = "" +var SMTPToken = "" + +var GitHubClientId = "" +var GitHubClientSecret = "" + +var WeChatServerAddress = "" +var WeChatServerToken = "" +var WeChatAccountQRCodeImageURL = "" + +var TurnstileSiteKey = "" +var TurnstileSecretKey = "" + +var QuotaForNewUser = 0 +var QuotaForInviter = 0 +var QuotaForInvitee = 0 +var ChannelDisableThreshold = 5.0 +var AutomaticDisableChannelEnabled = false +var QuotaRemindThreshold = 1000 +var PreConsumedQuota = 500 +var ApproximateTokenEnabled = false +var RetryTimes = 0 + +var RootUserEmail = "" + +var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" + +var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) +var RequestInterval = time.Duration(requestInterval) * time.Second + +var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY + +const ( + RoleGuestUser = 0 + RoleCommonUser = 1 + RoleAdminUser = 10 + RoleRootUser = 100 +) + +var ( + FileUploadPermission = RoleGuestUser + FileDownloadPermission = RoleGuestUser + ImageUploadPermission = RoleGuestUser + ImageDownloadPermission = RoleGuestUser +) + +// All duration's unit is seconds +// Shouldn't larger then RateLimitKeyExpirationDuration +var ( + GlobalApiRateLimitNum = 180 + GlobalApiRateLimitDuration int64 = 3 * 60 + + GlobalWebRateLimitNum = 60 + GlobalWebRateLimitDuration int64 = 3 * 60 + + UploadRateLimitNum = 10 + UploadRateLimitDuration int64 = 60 + + DownloadRateLimitNum = 10 + DownloadRateLimitDuration int64 = 60 + + CriticalRateLimitNum = 20 + CriticalRateLimitDuration int64 = 20 * 60 +) + +var RateLimitKeyExpirationDuration = 20 * time.Minute + +const ( + UserStatusEnabled = 1 // don't use 0, 0 is the default value! + UserStatusDisabled = 2 // also don't use 0 +) + +const ( + TokenStatusEnabled = 1 // don't use 0, 0 is the default value! + TokenStatusDisabled = 2 // also don't use 0 + TokenStatusExpired = 3 + TokenStatusExhausted = 4 +) + +const ( + RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value! + RedemptionCodeStatusDisabled = 2 // also don't use 0 + RedemptionCodeStatusUsed = 3 // also don't use 0 +) + +const ( + ChannelStatusUnknown = 0 + ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! + ChannelStatusDisabled = 2 // also don't use 0 +) + +const ( + ChannelTypeUnknown = 0 + ChannelTypeOpenAI = 1 + ChannelTypeAPI2D = 2 + ChannelTypeAzure = 3 + ChannelTypeCloseAI = 4 + ChannelTypeOpenAISB = 5 + ChannelTypeOpenAIMax = 6 + ChannelTypeOhMyGPT = 7 + ChannelTypeCustom = 8 + ChannelTypeAILS = 9 + ChannelTypeAIProxy = 10 + ChannelTypePaLM = 11 + ChannelTypeAPI2GPT = 12 + ChannelTypeAIGC2D = 13 + ChannelTypeAnthropic = 14 + ChannelTypeBaidu = 15 + ChannelTypeZhipu = 16 + ChannelTypeAli = 17 + ChannelTypeXunfei = 18 + ChannelType360 = 19 + ChannelTypeOpenRouter = 20 +) + +var ChannelBaseURLs = []string{ + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 +} diff --git a/common/crypto.go b/common/crypto.go new file mode 100644 index 0000000000000000000000000000000000000000..452284161fd5c3768feb12c14ba953b451edb5ed --- /dev/null +++ b/common/crypto.go @@ -0,0 +1,14 @@ +package common + +import "golang.org/x/crypto/bcrypt" + +func Password2Hash(password string) (string, error) { + passwordBytes := []byte(password) + hashedPassword, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost) + return string(hashedPassword), err +} + +func ValidatePasswordAndHash(password string, hash string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + return err == nil +} diff --git a/common/custom-event.go b/common/custom-event.go new file mode 100644 index 0000000000000000000000000000000000000000..69da4bc4b1660d834ae16dc2a1122ba47de32be0 --- /dev/null +++ b/common/custom-event.go @@ -0,0 +1,82 @@ +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package common + +import ( + "fmt" + "io" + "net/http" + "strings" +) + +type stringWriter interface { + io.Writer + writeString(string) (int, error) +} + +type stringWrapper struct { + io.Writer +} + +func (w stringWrapper) writeString(str string) (int, error) { + return w.Writer.Write([]byte(str)) +} + +func checkWriter(writer io.Writer) stringWriter { + if w, ok := writer.(stringWriter); ok { + return w + } else { + return stringWrapper{writer} + } +} + +// Server-Sent Events +// W3C Working Draft 29 October 2009 +// http://www.w3.org/TR/2009/WD-eventsource-20091029/ + +var contentType = []string{"text/event-stream"} +var noCache = []string{"no-cache"} + +var fieldReplacer = strings.NewReplacer( + "\n", "\\n", + "\r", "\\r") + +var dataReplacer = strings.NewReplacer( + "\n", "\ndata:", + "\r", "\\r") + +type CustomEvent struct { + Event string + Id string + Retry uint + Data interface{} +} + +func encode(writer io.Writer, event CustomEvent) error { + w := checkWriter(writer) + return writeData(w, event.Data) +} + +func writeData(w stringWriter, data interface{}) error { + dataReplacer.WriteString(w, fmt.Sprint(data)) + if strings.HasPrefix(data.(string), "data") { + w.writeString("\n\n") + } + return nil +} + +func (r CustomEvent) Render(w http.ResponseWriter) error { + r.WriteContentType(w) + return encode(w, r) +} + +func (r CustomEvent) WriteContentType(w http.ResponseWriter) { + header := w.Header() + header["Content-Type"] = contentType + + if _, exist := header["Cache-Control"]; !exist { + header["Cache-Control"] = noCache + } +} diff --git a/common/email.go b/common/email.go new file mode 100644 index 0000000000000000000000000000000000000000..74f4cccd9ef7282fad654356c3bf71bf07f1a6f4 --- /dev/null +++ b/common/email.go @@ -0,0 +1,67 @@ +package common + +import ( + "crypto/tls" + "encoding/base64" + "fmt" + "net/smtp" + "strings" +) + +func SendEmail(subject string, receiver string, content string) error { + if SMTPFrom == "" { // for compatibility + SMTPFrom = SMTPAccount + } + encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) + mail := []byte(fmt.Sprintf("To: %s\r\n"+ + "From: %s<%s>\r\n"+ + "Subject: %s\r\n"+ + "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", + receiver, SystemName, SMTPFrom, encodedSubject, content)) + auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) + addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) + to := strings.Split(receiver, ";") + var err error + if SMTPPort == 465 { + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + ServerName: SMTPServer, + } + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig) + if err != nil { + return err + } + client, err := smtp.NewClient(conn, SMTPServer) + if err != nil { + return err + } + defer client.Close() + if err = client.Auth(auth); err != nil { + return err + } + if err = client.Mail(SMTPFrom); err != nil { + return err + } + receiverEmails := strings.Split(receiver, ";") + for _, receiver := range receiverEmails { + if err = client.Rcpt(receiver); err != nil { + return err + } + } + w, err := client.Data() + if err != nil { + return err + } + _, err = w.Write(mail) + if err != nil { + return err + } + err = w.Close() + if err != nil { + return err + } + } else { + err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) + } + return err +} diff --git a/common/embed-file-system.go b/common/embed-file-system.go new file mode 100644 index 0000000000000000000000000000000000000000..3ea02cf81dedb099cde02447bd44efaab6c66c68 --- /dev/null +++ b/common/embed-file-system.go @@ -0,0 +1,32 @@ +package common + +import ( + "embed" + "github.com/gin-contrib/static" + "io/fs" + "net/http" +) + +// Credit: https://github.com/gin-contrib/static/issues/19 + +type embedFileSystem struct { + http.FileSystem +} + +func (e embedFileSystem) Exists(prefix string, path string) bool { + _, err := e.Open(path) + if err != nil { + return false + } + return true +} + +func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem { + efs, err := fs.Sub(fsEmbed, targetPath) + if err != nil { + panic(err) + } + return embedFileSystem{ + FileSystem: http.FS(efs), + } +} diff --git a/common/gin.go b/common/gin.go new file mode 100644 index 0000000000000000000000000000000000000000..ffa1e2183355bdcb9fd19fdecb716a309050463d --- /dev/null +++ b/common/gin.go @@ -0,0 +1,26 @@ +package common + +import ( + "bytes" + "encoding/json" + "github.com/gin-gonic/gin" + "io" +) + +func UnmarshalBodyReusable(c *gin.Context, v any) error { + requestBody, err := io.ReadAll(c.Request.Body) + if err != nil { + return err + } + err = c.Request.Body.Close() + if err != nil { + return err + } + err = json.Unmarshal(requestBody, &v) + if err != nil { + return err + } + // Reset request body + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + return nil +} diff --git a/common/group-ratio.go b/common/group-ratio.go new file mode 100644 index 0000000000000000000000000000000000000000..1ec73c780b0c7e449ba55a8a0d3c08acc488cec0 --- /dev/null +++ b/common/group-ratio.go @@ -0,0 +1,31 @@ +package common + +import "encoding/json" + +var GroupRatio = map[string]float64{ + "default": 1, + "vip": 1, + "svip": 1, +} + +func GroupRatio2JSONString() string { + jsonBytes, err := json.Marshal(GroupRatio) + if err != nil { + SysError("error marshalling model ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateGroupRatioByJSONString(jsonStr string) error { + GroupRatio = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &GroupRatio) +} + +func GetGroupRatio(name string) float64 { + ratio, ok := GroupRatio[name] + if !ok { + SysError("group ratio not found: " + name) + return 1 + } + return ratio +} diff --git a/common/init.go b/common/init.go new file mode 100644 index 0000000000000000000000000000000000000000..0f22c69b6a643272b99ce5c5684036aedc10ad76 --- /dev/null +++ b/common/init.go @@ -0,0 +1,57 @@ +package common + +import ( + "flag" + "fmt" + "log" + "os" + "path/filepath" +) + +var ( + Port = flag.Int("port", 3000, "the listening port") + PrintVersion = flag.Bool("version", false, "print version and exit") + PrintHelp = flag.Bool("help", false, "print help and exit") + LogDir = flag.String("log-dir", "", "specify the log directory") +) + +func printHelp() { + fmt.Println("One API " + Version + " - All in one API service for OpenAI API.") + fmt.Println("Copyright (C) 2023 JustSong. All rights reserved.") + fmt.Println("GitHub: https://github.com/songquanpeng/one-api") + fmt.Println("Usage: one-api [--port ] [--log-dir ] [--version] [--help]") +} + +func init() { + flag.Parse() + + if *PrintVersion { + fmt.Println(Version) + os.Exit(0) + } + + if *PrintHelp { + printHelp() + os.Exit(0) + } + + if os.Getenv("SESSION_SECRET") != "" { + SessionSecret = os.Getenv("SESSION_SECRET") + } + if os.Getenv("SQLITE_PATH") != "" { + SQLitePath = os.Getenv("SQLITE_PATH") + } + if *LogDir != "" { + var err error + *LogDir, err = filepath.Abs(*LogDir) + if err != nil { + log.Fatal(err) + } + if _, err := os.Stat(*LogDir); os.IsNotExist(err) { + err = os.Mkdir(*LogDir, 0777) + if err != nil { + log.Fatal(err) + } + } + } +} diff --git a/common/logger.go b/common/logger.go new file mode 100644 index 0000000000000000000000000000000000000000..3658dbdbfd4053d2e4fa53c4efcf0d7497964403 --- /dev/null +++ b/common/logger.go @@ -0,0 +1,52 @@ +package common + +import ( + "fmt" + "github.com/gin-gonic/gin" + "io" + "log" + "os" + "path/filepath" + "time" +) + +func SetupGinLog() { + if *LogDir != "" { + commonLogPath := filepath.Join(*LogDir, "common.log") + errorLogPath := filepath.Join(*LogDir, "error.log") + commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Fatal("failed to open log file") + } + errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Fatal("failed to open log file") + } + gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd) + gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd) + } +} + +func SysLog(s string) { + t := time.Now() + _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) +} + +func SysError(s string) { + t := time.Now() + _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) +} + +func FatalLog(v ...any) { + t := time.Now() + _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) + os.Exit(1) +} + +func LogQuota(quota int) string { + if DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) + } else { + return fmt.Sprintf("%d 点额度", quota) + } +} diff --git a/common/model-ratio.go b/common/model-ratio.go new file mode 100644 index 0000000000000000000000000000000000000000..70758805ce0e198dc743ab1df97a8799265d36bb --- /dev/null +++ b/common/model-ratio.go @@ -0,0 +1,99 @@ +package common + +import ( + "encoding/json" + "strings" +) + +// ModelRatio +// https://platform.openai.com/docs/models/model-endpoint-compatibility +// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf +// https://openai.com/pricing +// TODO: when a new api is enabled, check the pricing here +// 1 === $0.002 / 1K tokens +// 1 === ¥0.014 / 1k tokens +var ModelRatio = map[string]float64{ + "gpt-4": 15, + "gpt-4-0314": 15, + "gpt-4-0613": 15, + "gpt-4-32k": 30, + "gpt-4-32k-0314": 30, + "gpt-4-32k-0613": 30, + "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens + "gpt-3.5-turbo-0301": 0.75, + "gpt-3.5-turbo-0613": 0.75, + "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens + "gpt-3.5-turbo-16k-0613": 1.5, + "text-ada-001": 0.2, + "text-babbage-001": 0.25, + "text-curie-001": 1, + "text-davinci-002": 10, + "text-davinci-003": 10, + "text-davinci-edit-001": 10, + "code-davinci-edit-001": 10, + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "davinci": 10, + "curie": 10, + "babbage": 10, + "ada": 10, + "text-embedding-ada-002": 0.05, + "text-search-ada-doc-001": 10, + "text-moderation-stable": 0.1, + "text-moderation-latest": 0.1, + "dall-e": 8, + "claude-instant-1": 0.815, // $1.63 / 1M tokens + "claude-2": 5.51, // $11.02 / 1M tokens + "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens + "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens + "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens + "PaLM-2": 1, + "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens + "chatglm_std": 0.3572, // ¥0.005 / 1k tokens + "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "qwen-v1": 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag + "qwen-plus-v1": 0.5715, // Same as above + "SparkDesk": 0.8572, // TBD + "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens + "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens + "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens + "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens + "360GPT_S2_V9.4": 0.8572, // ¥0.012 / 1k tokens +} + +func ModelRatio2JSONString() string { + jsonBytes, err := json.Marshal(ModelRatio) + if err != nil { + SysError("error marshalling model ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateModelRatioByJSONString(jsonStr string) error { + ModelRatio = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &ModelRatio) +} + +func GetModelRatio(name string) float64 { + ratio, ok := ModelRatio[name] + if !ok { + SysError("model ratio not found: " + name) + return 30 + } + return ratio +} + +func GetCompletionRatio(name string) float64 { + if strings.HasPrefix(name, "gpt-3.5") { + return 1.333333 + } + if strings.HasPrefix(name, "gpt-4") { + return 2 + } + if strings.HasPrefix(name, "claude-instant-1") { + return 3.38 + } + if strings.HasPrefix(name, "claude-2") { + return 2.965517 + } + return 1 +} diff --git a/common/rate-limit.go b/common/rate-limit.go new file mode 100644 index 0000000000000000000000000000000000000000..301c101c974809a660c4230843c2ae43289c6464 --- /dev/null +++ b/common/rate-limit.go @@ -0,0 +1,70 @@ +package common + +import ( + "sync" + "time" +) + +type InMemoryRateLimiter struct { + store map[string]*[]int64 + mutex sync.Mutex + expirationDuration time.Duration +} + +func (l *InMemoryRateLimiter) Init(expirationDuration time.Duration) { + if l.store == nil { + l.mutex.Lock() + if l.store == nil { + l.store = make(map[string]*[]int64) + l.expirationDuration = expirationDuration + if expirationDuration > 0 { + go l.clearExpiredItems() + } + } + l.mutex.Unlock() + } +} + +func (l *InMemoryRateLimiter) clearExpiredItems() { + for { + time.Sleep(l.expirationDuration) + l.mutex.Lock() + now := time.Now().Unix() + for key := range l.store { + queue := l.store[key] + size := len(*queue) + if size == 0 || now-(*queue)[size-1] > int64(l.expirationDuration.Seconds()) { + delete(l.store, key) + } + } + l.mutex.Unlock() + } +} + +// Request parameter duration's unit is seconds +func (l *InMemoryRateLimiter) Request(key string, maxRequestNum int, duration int64) bool { + l.mutex.Lock() + defer l.mutex.Unlock() + // [old <-- new] + queue, ok := l.store[key] + now := time.Now().Unix() + if ok { + if len(*queue) < maxRequestNum { + *queue = append(*queue, now) + return true + } else { + if now-(*queue)[0] >= duration { + *queue = (*queue)[1:] + *queue = append(*queue, now) + return true + } else { + return false + } + } + } else { + s := make([]int64, 0, maxRequestNum) + l.store[key] = &s + *(l.store[key]) = append(*(l.store[key]), now) + } + return true +} diff --git a/common/redis.go b/common/redis.go new file mode 100644 index 0000000000000000000000000000000000000000..12c477b84af40930d1fa1c6a11e93b4f5f8b733e --- /dev/null +++ b/common/redis.go @@ -0,0 +1,68 @@ +package common + +import ( + "context" + "github.com/go-redis/redis/v8" + "os" + "time" +) + +var RDB *redis.Client +var RedisEnabled = true + +// InitRedisClient This function is called after init() +func InitRedisClient() (err error) { + if os.Getenv("REDIS_CONN_STRING") == "" { + RedisEnabled = false + SysLog("REDIS_CONN_STRING not set, Redis is not enabled") + return nil + } + if os.Getenv("SYNC_FREQUENCY") == "" { + RedisEnabled = false + SysLog("SYNC_FREQUENCY not set, Redis is disabled") + return nil + } + SysLog("Redis is enabled") + opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) + if err != nil { + FatalLog("failed to parse Redis connection string: " + err.Error()) + } + RDB = redis.NewClient(opt) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err = RDB.Ping(ctx).Result() + if err != nil { + FatalLog("Redis ping test failed: " + err.Error()) + } + return err +} + +func ParseRedisOption() *redis.Options { + opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) + if err != nil { + FatalLog("failed to parse Redis connection string: " + err.Error()) + } + return opt +} + +func RedisSet(key string, value string, expiration time.Duration) error { + ctx := context.Background() + return RDB.Set(ctx, key, value, expiration).Err() +} + +func RedisGet(key string) (string, error) { + ctx := context.Background() + return RDB.Get(ctx, key).Result() +} + +func RedisDel(key string) error { + ctx := context.Background() + return RDB.Del(ctx, key).Err() +} + +func RedisDecrease(key string, value int64) error { + ctx := context.Background() + return RDB.DecrBy(ctx, key, value).Err() +} diff --git a/common/utils.go b/common/utils.go new file mode 100644 index 0000000000000000000000000000000000000000..bb9b7e0cbab5c03bddcee1aa00d9b54061a1c38c --- /dev/null +++ b/common/utils.go @@ -0,0 +1,192 @@ +package common + +import ( + "fmt" + "github.com/google/uuid" + "html/template" + "log" + "math/rand" + "net" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "time" +) + +func OpenBrowser(url string) { + var err error + + switch runtime.GOOS { + case "linux": + err = exec.Command("xdg-open", url).Start() + case "windows": + err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + err = exec.Command("open", url).Start() + } + if err != nil { + log.Println(err) + } +} + +func GetIp() (ip string) { + ips, err := net.InterfaceAddrs() + if err != nil { + log.Println(err) + return ip + } + + for _, a := range ips { + if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.To4() != nil { + ip = ipNet.IP.String() + if strings.HasPrefix(ip, "10") { + return + } + if strings.HasPrefix(ip, "172") { + return + } + if strings.HasPrefix(ip, "192.168") { + return + } + ip = "" + } + } + } + return +} + +var sizeKB = 1024 +var sizeMB = sizeKB * 1024 +var sizeGB = sizeMB * 1024 + +func Bytes2Size(num int64) string { + numStr := "" + unit := "B" + if num/int64(sizeGB) > 1 { + numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) + unit = "GB" + } else if num/int64(sizeMB) > 1 { + numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) + unit = "MB" + } else if num/int64(sizeKB) > 1 { + numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) + unit = "KB" + } else { + numStr = fmt.Sprintf("%d", num) + } + return numStr + " " + unit +} + +func Seconds2Time(num int) (time string) { + if num/31104000 > 0 { + time += strconv.Itoa(num/31104000) + " 年 " + num %= 31104000 + } + if num/2592000 > 0 { + time += strconv.Itoa(num/2592000) + " 个月 " + num %= 2592000 + } + if num/86400 > 0 { + time += strconv.Itoa(num/86400) + " 天 " + num %= 86400 + } + if num/3600 > 0 { + time += strconv.Itoa(num/3600) + " 小时 " + num %= 3600 + } + if num/60 > 0 { + time += strconv.Itoa(num/60) + " 分钟 " + num %= 60 + } + time += strconv.Itoa(num) + " 秒" + return +} + +func Interface2String(inter interface{}) string { + switch inter.(type) { + case string: + return inter.(string) + case int: + return fmt.Sprintf("%d", inter.(int)) + case float64: + return fmt.Sprintf("%f", inter.(float64)) + } + return "Not Implemented" +} + +func UnescapeHTML(x string) interface{} { + return template.HTML(x) +} + +func IntMax(a int, b int) int { + if a >= b { + return a + } else { + return b + } +} + +func GetUUID() string { + code := uuid.New().String() + code = strings.Replace(code, "-", "", -1) + return code +} + +const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func GenerateKey() string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, 48) + for i := 0; i < 16; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + uuid_ := GetUUID() + for i := 0; i < 32; i++ { + c := uuid_[i] + if i%2 == 0 && c >= 'a' && c <= 'z' { + c = c - 'a' + 'A' + } + key[i+16] = c + } + return string(key) +} + +func GetRandomString(length int) string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + return string(key) +} + +func GetTimestamp() int64 { + return time.Now().Unix() +} + +func Max(a int, b int) int { + if a >= b { + return a + } else { + return b + } +} + +func GetOrDefault(env string, defaultValue int) int { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.Atoi(os.Getenv(env)) + if err != nil { + SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) + return defaultValue + } + return num +} diff --git a/common/validate.go b/common/validate.go new file mode 100644 index 0000000000000000000000000000000000000000..b3c78591078bd4a8ab321f9368fcc0721d8f60f0 --- /dev/null +++ b/common/validate.go @@ -0,0 +1,9 @@ +package common + +import "github.com/go-playground/validator/v10" + +var Validate *validator.Validate + +func init() { + Validate = validator.New() +} diff --git a/common/verification.go b/common/verification.go new file mode 100644 index 0000000000000000000000000000000000000000..d8ccd6eafca3e06adb6a5c90e7d5e15f201d5354 --- /dev/null +++ b/common/verification.go @@ -0,0 +1,77 @@ +package common + +import ( + "github.com/google/uuid" + "strings" + "sync" + "time" +) + +type verificationValue struct { + code string + time time.Time +} + +const ( + EmailVerificationPurpose = "v" + PasswordResetPurpose = "r" +) + +var verificationMutex sync.Mutex +var verificationMap map[string]verificationValue +var verificationMapMaxSize = 10 +var VerificationValidMinutes = 10 + +func GenerateVerificationCode(length int) string { + code := uuid.New().String() + code = strings.Replace(code, "-", "", -1) + if length == 0 { + return code + } + return code[:length] +} + +func RegisterVerificationCodeWithKey(key string, code string, purpose string) { + verificationMutex.Lock() + defer verificationMutex.Unlock() + verificationMap[purpose+key] = verificationValue{ + code: code, + time: time.Now(), + } + if len(verificationMap) > verificationMapMaxSize { + removeExpiredPairs() + } +} + +func VerifyCodeWithKey(key string, code string, purpose string) bool { + verificationMutex.Lock() + defer verificationMutex.Unlock() + value, okay := verificationMap[purpose+key] + now := time.Now() + if !okay || int(now.Sub(value.time).Seconds()) >= VerificationValidMinutes*60 { + return false + } + return code == value.code +} + +func DeleteKey(key string, purpose string) { + verificationMutex.Lock() + defer verificationMutex.Unlock() + delete(verificationMap, purpose+key) +} + +// no lock inside, so the caller must lock the verificationMap before calling! +func removeExpiredPairs() { + now := time.Now() + for key := range verificationMap { + if int(now.Sub(verificationMap[key].time).Seconds()) >= VerificationValidMinutes*60 { + delete(verificationMap, key) + } + } +} + +func init() { + verificationMutex.Lock() + defer verificationMutex.Unlock() + verificationMap = make(map[string]verificationValue) +} diff --git a/controller/billing.go b/controller/billing.go new file mode 100644 index 0000000000000000000000000000000000000000..79eae1e242f13eca95ddf3f094db8a78330b0596 --- /dev/null +++ b/controller/billing.go @@ -0,0 +1,91 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "one-api/common" + "one-api/model" +) + +func GetSubscription(c *gin.Context) { + var remainQuota int + var usedQuota int + var err error + var token *model.Token + var expiredTime int64 + if common.DisplayTokenStatEnabled { + tokenId := c.GetInt("token_id") + token, err = model.GetTokenById(tokenId) + expiredTime = token.ExpiredTime + remainQuota = token.RemainQuota + usedQuota = token.UsedQuota + } else { + userId := c.GetInt("id") + remainQuota, err = model.GetUserQuota(userId) + usedQuota, err = model.GetUserUsedQuota(userId) + } + if expiredTime <= 0 { + expiredTime = 0 + } + if err != nil { + openAIError := OpenAIError{ + Message: err.Error(), + Type: "one_api_error", + } + c.JSON(200, gin.H{ + "error": openAIError, + }) + return + } + quota := remainQuota + usedQuota + amount := float64(quota) + if common.DisplayInCurrencyEnabled { + amount /= common.QuotaPerUnit + } + if token != nil && token.UnlimitedQuota { + amount = 100000000 + } + subscription := OpenAISubscriptionResponse{ + Object: "billing_subscription", + HasPaymentMethod: true, + SoftLimitUSD: amount, + HardLimitUSD: amount, + SystemHardLimitUSD: amount, + AccessUntil: expiredTime, + } + c.JSON(200, subscription) + return +} + +func GetUsage(c *gin.Context) { + var quota int + var err error + var token *model.Token + if common.DisplayTokenStatEnabled { + tokenId := c.GetInt("token_id") + token, err = model.GetTokenById(tokenId) + quota = token.UsedQuota + } else { + userId := c.GetInt("id") + quota, err = model.GetUserUsedQuota(userId) + } + if err != nil { + openAIError := OpenAIError{ + Message: err.Error(), + Type: "one_api_error", + } + c.JSON(200, gin.H{ + "error": openAIError, + }) + return + } + amount := float64(quota) + if common.DisplayInCurrencyEnabled { + amount /= common.QuotaPerUnit + } + usage := OpenAIUsageResponse{ + Object: "list", + TotalUsage: amount * 100, + } + c.JSON(200, usage) + return +} diff --git a/controller/channel-billing.go b/controller/channel-billing.go new file mode 100644 index 0000000000000000000000000000000000000000..46262f6c95d18ee3a2f9d7b8f5e44bced8f2601a --- /dev/null +++ b/controller/channel-billing.go @@ -0,0 +1,345 @@ +package controller + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + "time" + + "github.com/gin-gonic/gin" +) + +// https://github.com/songquanpeng/one-api/issues/79 + +type OpenAISubscriptionResponse struct { + Object string `json:"object"` + HasPaymentMethod bool `json:"has_payment_method"` + SoftLimitUSD float64 `json:"soft_limit_usd"` + HardLimitUSD float64 `json:"hard_limit_usd"` + SystemHardLimitUSD float64 `json:"system_hard_limit_usd"` + AccessUntil int64 `json:"access_until"` +} + +type OpenAIUsageDailyCost struct { + Timestamp float64 `json:"timestamp"` + LineItems []struct { + Name string `json:"name"` + Cost float64 `json:"cost"` + } +} + +type OpenAICreditGrants struct { + Object string `json:"object"` + TotalGranted float64 `json:"total_granted"` + TotalUsed float64 `json:"total_used"` + TotalAvailable float64 `json:"total_available"` +} + +type OpenAIUsageResponse struct { + Object string `json:"object"` + //DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"` + TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar +} + +type OpenAISBUsageResponse struct { + Msg string `json:"msg"` + Data *struct { + Credit string `json:"credit"` + } `json:"data"` +} + +type AIProxyUserOverviewResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + ErrorCode int `json:"error_code"` + Data struct { + TotalPoints float64 `json:"totalPoints"` + } `json:"data"` +} + +type API2GPTUsageResponse struct { + Object string `json:"object"` + TotalGranted float64 `json:"total_granted"` + TotalUsed float64 `json:"total_used"` + TotalRemaining float64 `json:"total_remaining"` +} + +type APGC2DGPTUsageResponse struct { + //Grants interface{} `json:"grants"` + Object string `json:"object"` + TotalAvailable float64 `json:"total_available"` + TotalGranted float64 `json:"total_granted"` + TotalUsed float64 `json:"total_used"` +} + +// GetAuthHeader get auth header +func GetAuthHeader(token string) http.Header { + h := http.Header{} + h.Add("Authorization", fmt.Sprintf("Bearer %s", token)) + return h +} + +func GetResponseBody(method, url string, channel *model.Channel, headers http.Header) ([]byte, error) { + req, err := http.NewRequest(method, url, nil) + if err != nil { + return nil, err + } + for k := range headers { + req.Header.Add(k, headers.Get(k)) + } + res, err := httpClient.Do(req) + if err != nil { + return nil, err + } + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("status code: %d", res.StatusCode) + } + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + err = res.Body.Close() + if err != nil { + return nil, err + } + return body, nil +} + +func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { + url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL) + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + + if err != nil { + return 0, err + } + response := OpenAICreditGrants{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + channel.UpdateBalance(response.TotalAvailable) + return response.TotalAvailable, nil +} + +func updateChannelOpenAISBBalance(channel *model.Channel) (float64, error) { + url := fmt.Sprintf("https://api.openai-sb.com/sb-api/user/status?api_key=%s", channel.Key) + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := OpenAISBUsageResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + if response.Data == nil { + return 0, errors.New(response.Msg) + } + balance, err := strconv.ParseFloat(response.Data.Credit, 64) + if err != nil { + return 0, err + } + channel.UpdateBalance(balance) + return balance, nil +} + +func updateChannelAIProxyBalance(channel *model.Channel) (float64, error) { + url := "https://aiproxy.io/api/report/getUserOverview" + headers := http.Header{} + headers.Add("Api-Key", channel.Key) + body, err := GetResponseBody("GET", url, channel, headers) + if err != nil { + return 0, err + } + response := AIProxyUserOverviewResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + if !response.Success { + return 0, fmt.Errorf("code: %d, message: %s", response.ErrorCode, response.Message) + } + channel.UpdateBalance(response.Data.TotalPoints) + return response.Data.TotalPoints, nil +} + +func updateChannelAPI2GPTBalance(channel *model.Channel) (float64, error) { + url := "https://api.api2gpt.com/dashboard/billing/credit_grants" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + + if err != nil { + return 0, err + } + response := API2GPTUsageResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + channel.UpdateBalance(response.TotalRemaining) + return response.TotalRemaining, nil +} + +func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { + url := "https://api.aigc2d.com/dashboard/billing/credit_grants" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + response := APGC2DGPTUsageResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + channel.UpdateBalance(response.TotalAvailable) + return response.TotalAvailable, nil +} + +func updateChannelBalance(channel *model.Channel) (float64, error) { + baseURL := common.ChannelBaseURLs[channel.Type] + if channel.BaseURL == "" { + channel.BaseURL = baseURL + } + switch channel.Type { + case common.ChannelTypeOpenAI: + if channel.BaseURL != "" { + baseURL = channel.BaseURL + } + case common.ChannelTypeAzure: + return 0, errors.New("尚未实现") + case common.ChannelTypeCustom: + baseURL = channel.BaseURL + case common.ChannelTypeCloseAI: + return updateChannelCloseAIBalance(channel) + case common.ChannelTypeOpenAISB: + return updateChannelOpenAISBBalance(channel) + case common.ChannelTypeAIProxy: + return updateChannelAIProxyBalance(channel) + case common.ChannelTypeAPI2GPT: + return updateChannelAPI2GPTBalance(channel) + case common.ChannelTypeAIGC2D: + return updateChannelAIGC2DBalance(channel) + default: + return 0, errors.New("尚未实现") + } + url := fmt.Sprintf("%s/v1/dashboard/billing/subscription", baseURL) + + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + subscription := OpenAISubscriptionResponse{} + err = json.Unmarshal(body, &subscription) + if err != nil { + return 0, err + } + now := time.Now() + startDate := fmt.Sprintf("%s-01", now.Format("2006-01")) + endDate := now.Format("2006-01-02") + if !subscription.HasPaymentMethod { + startDate = now.AddDate(0, 0, -100).Format("2006-01-02") + } + url = fmt.Sprintf("%s/v1/dashboard/billing/usage?start_date=%s&end_date=%s", baseURL, startDate, endDate) + body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + usage := OpenAIUsageResponse{} + err = json.Unmarshal(body, &usage) + if err != nil { + return 0, err + } + balance := subscription.HardLimitUSD - usage.TotalUsage/100 + channel.UpdateBalance(balance) + return balance, nil +} + +func UpdateChannelBalance(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel, err := model.GetChannelById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + balance, err := updateChannelBalance(channel) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "balance": balance, + }) + return +} + +func updateAllChannelsBalance() error { + channels, err := model.GetAllChannels(0, 0, true) + if err != nil { + return err + } + for _, channel := range channels { + if channel.Status != common.ChannelStatusEnabled { + continue + } + // TODO: support Azure + if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { + continue + } + balance, err := updateChannelBalance(channel) + if err != nil { + continue + } else { + // err is nil & balance <= 0 means quota is used up + if balance <= 0 { + disableChannel(channel.Id, channel.Name, "余额不足") + } + } + time.Sleep(common.RequestInterval) + } + return nil +} + +func UpdateAllChannelsBalance(c *gin.Context) { + // TODO: make it async + err := updateAllChannelsBalance() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func AutomaticallyUpdateChannels(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Minute) + common.SysLog("updating all channels") + _ = updateAllChannelsBalance() + common.SysLog("channels update done") + } +} diff --git a/controller/channel-test.go b/controller/channel-test.go new file mode 100644 index 0000000000000000000000000000000000000000..686521eff6efae9932c5ab29d7342ec12a8687ac --- /dev/null +++ b/controller/channel-test.go @@ -0,0 +1,223 @@ +package controller + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + "sync" + "time" +) + +func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { + switch channel.Type { + case common.ChannelTypePaLM: + fallthrough + case common.ChannelTypeAnthropic: + fallthrough + case common.ChannelTypeBaidu: + fallthrough + case common.ChannelTypeZhipu: + fallthrough + case common.ChannelTypeAli: + fallthrough + case common.ChannelType360: + fallthrough + case common.ChannelTypeXunfei: + return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil + case common.ChannelTypeAzure: + request.Model = "gpt-35-turbo" + default: + request.Model = "gpt-3.5-turbo" + } + requestURL := common.ChannelBaseURLs[channel.Type] + if channel.Type == common.ChannelTypeAzure { + requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) + } else { + if channel.BaseURL != "" { + requestURL = channel.BaseURL + } + requestURL += "/v1/chat/completions" + } + + jsonData, err := json.Marshal(request) + if err != nil { + return err, nil + } + req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) + if err != nil { + return err, nil + } + if channel.Type == common.ChannelTypeAzure { + req.Header.Set("api-key", channel.Key) + } else { + req.Header.Set("Authorization", "Bearer "+channel.Key) + } + req.Header.Set("Content-Type", "application/json") + resp, err := httpClient.Do(req) + if err != nil { + return err, nil + } + defer resp.Body.Close() + var response TextResponse + err = json.NewDecoder(resp.Body).Decode(&response) + if err != nil { + return err, nil + } + if response.Usage.CompletionTokens == 0 { + return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error + } + return nil, nil +} + +func buildTestRequest() *ChatRequest { + testRequest := &ChatRequest{ + Model: "", // this will be set later + MaxTokens: 1, + } + testMessage := Message{ + Role: "user", + Content: "hi", + } + testRequest.Messages = append(testRequest.Messages, testMessage) + return testRequest +} + +func TestChannel(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel, err := model.GetChannelById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + testRequest := buildTestRequest() + tik := time.Now() + err, _ = testChannel(channel, *testRequest) + tok := time.Now() + milliseconds := tok.Sub(tik).Milliseconds() + go channel.UpdateResponseTime(milliseconds) + consumedTime := float64(milliseconds) / 1000.0 + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + "time": consumedTime, + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "time": consumedTime, + }) + return +} + +var testAllChannelsLock sync.Mutex +var testAllChannelsRunning bool = false + +// disable & notify +func disableChannel(channelId int, channelName string, reason string) { + if common.RootUserEmail == "" { + common.RootUserEmail = model.GetRootUserEmail() + } + model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) + err := common.SendEmail(subject, common.RootUserEmail, content) + if err != nil { + common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + } +} + +func testAllChannels(notify bool) error { + if common.RootUserEmail == "" { + common.RootUserEmail = model.GetRootUserEmail() + } + testAllChannelsLock.Lock() + if testAllChannelsRunning { + testAllChannelsLock.Unlock() + return errors.New("测试已在运行中") + } + testAllChannelsRunning = true + testAllChannelsLock.Unlock() + channels, err := model.GetAllChannels(0, 0, true) + if err != nil { + return err + } + testRequest := buildTestRequest() + var disableThreshold = int64(common.ChannelDisableThreshold * 1000) + if disableThreshold == 0 { + disableThreshold = 10000000 // a impossible value + } + go func() { + for _, channel := range channels { + if channel.Status != common.ChannelStatusEnabled { + continue + } + tik := time.Now() + err, openaiErr := testChannel(channel, *testRequest) + tok := time.Now() + milliseconds := tok.Sub(tik).Milliseconds() + if milliseconds > disableThreshold { + err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) + disableChannel(channel.Id, channel.Name, err.Error()) + } + if shouldDisableChannel(openaiErr, -1) { + disableChannel(channel.Id, channel.Name, err.Error()) + } + channel.UpdateResponseTime(milliseconds) + time.Sleep(common.RequestInterval) + } + testAllChannelsLock.Lock() + testAllChannelsRunning = false + testAllChannelsLock.Unlock() + if notify { + err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") + if err != nil { + common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + } + } + }() + return nil +} + +func TestAllChannels(c *gin.Context) { + err := testAllChannels(true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func AutomaticallyTestChannels(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Minute) + common.SysLog("testing all channels") + _ = testAllChannels(false) + common.SysLog("channel test finished") + } +} diff --git a/controller/channel.go b/controller/channel.go new file mode 100644 index 0000000000000000000000000000000000000000..8afc0eedbaab5b98e5667b6a08181fbf01e9b713 --- /dev/null +++ b/controller/channel.go @@ -0,0 +1,154 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + "strings" +) + +func GetAllChannels(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": channels, + }) + return +} + +func SearchChannels(c *gin.Context) { + keyword := c.Query("keyword") + channels, err := model.SearchChannels(keyword) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": channels, + }) + return +} + +func GetChannel(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel, err := model.GetChannelById(id, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": channel, + }) + return +} + +func AddChannel(c *gin.Context) { + channel := model.Channel{} + err := c.ShouldBindJSON(&channel) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + channel.CreatedTime = common.GetTimestamp() + keys := strings.Split(channel.Key, "\n") + channels := make([]model.Channel, 0) + for _, key := range keys { + if key == "" { + continue + } + localChannel := channel + localChannel.Key = key + channels = append(channels, localChannel) + } + err = model.BatchInsertChannels(channels) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func DeleteChannel(c *gin.Context) { + id, _ := strconv.Atoi(c.Param("id")) + channel := model.Channel{Id: id} + err := channel.Delete() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func UpdateChannel(c *gin.Context) { + channel := model.Channel{} + err := c.ShouldBindJSON(&channel) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + err = channel.Update() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": channel, + }) + return +} diff --git a/controller/github.go b/controller/github.go new file mode 100644 index 0000000000000000000000000000000000000000..e1c64130814b9d00301f00fbbf61a1af1eac7629 --- /dev/null +++ b/controller/github.go @@ -0,0 +1,207 @@ +package controller + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + "time" +) + +type GitHubOAuthResponse struct { + AccessToken string `json:"access_token"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` +} + +type GitHubUser struct { + Login string `json:"login"` + Name string `json:"name"` + Email string `json:"email"` +} + +func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code} + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + common.SysLog(err.Error()) + return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") + } + defer res.Body.Close() + var oAuthResponse GitHubOAuthResponse + err = json.NewDecoder(res.Body).Decode(&oAuthResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", "https://api.github.com/user", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) + res2, err := client.Do(req) + if err != nil { + common.SysLog(err.Error()) + return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") + } + defer res2.Body.Close() + var githubUser GitHubUser + err = json.NewDecoder(res2.Body).Decode(&githubUser) + if err != nil { + return nil, err + } + if githubUser.Login == "" { + return nil, errors.New("返回值非法,用户字段为空,请稍后重试!") + } + return &githubUser, nil +} + +func GitHubOAuth(c *gin.Context) { + session := sessions.Default(c) + username := session.Get("username") + if username != nil { + GitHubBind(c) + return + } + + if !common.GitHubOAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 GitHub 登录以及注册", + }) + return + } + code := c.Query("code") + githubUser, err := getGitHubUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + GitHubId: githubUser.Login, + } + if model.IsGitHubIdAlreadyTaken(user.GitHubId) { + err := user.FillUserByGitHubId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if common.RegisterEnabled { + user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) + if githubUser.Name != "" { + user.DisplayName = githubUser.Name + } else { + user.DisplayName = "GitHub User" + } + user.Email = githubUser.Email + user.Role = common.RoleCommonUser + user.Status = common.UserStatusEnabled + + if err := user.Insert(0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != common.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + setupLogin(&user, c) +} + +func GitHubBind(c *gin.Context) { + if !common.GitHubOAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 GitHub 登录以及注册", + }) + return + } + code := c.Query("code") + githubUser, err := getGitHubUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + GitHubId: githubUser.Login, + } + if model.IsGitHubIdAlreadyTaken(user.GitHubId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 GitHub 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.GitHubId = githubUser.Login + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/group.go b/controller/group.go new file mode 100644 index 0000000000000000000000000000000000000000..2b2f6006fad4bd7ce9ecfdec174b47ee34caf8e0 --- /dev/null +++ b/controller/group.go @@ -0,0 +1,19 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" +) + +func GetGroups(c *gin.Context) { + groupNames := make([]string, 0) + for groupName, _ := range common.GroupRatio { + groupNames = append(groupNames, groupName) + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": groupNames, + }) +} diff --git a/controller/log.go b/controller/log.go new file mode 100644 index 0000000000000000000000000000000000000000..ba04334976f83099b63a2df2354470e829566ffa --- /dev/null +++ b/controller/log.go @@ -0,0 +1,133 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "one-api/common" + "one-api/model" + "strconv" +) + +func GetAllLogs(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + logType, _ := strconv.Atoi(c.Query("type")) + startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) + endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) + username := c.Query("username") + tokenName := c.Query("token_name") + modelName := c.Query("model_name") + logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) + if err != nil { + c.JSON(200, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(200, gin.H{ + "success": true, + "message": "", + "data": logs, + }) +} + +func GetUserLogs(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + userId := c.GetInt("id") + logType, _ := strconv.Atoi(c.Query("type")) + startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) + endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) + tokenName := c.Query("token_name") + modelName := c.Query("model_name") + logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) + if err != nil { + c.JSON(200, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(200, gin.H{ + "success": true, + "message": "", + "data": logs, + }) +} + +func SearchAllLogs(c *gin.Context) { + keyword := c.Query("keyword") + logs, err := model.SearchAllLogs(keyword) + if err != nil { + c.JSON(200, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(200, gin.H{ + "success": true, + "message": "", + "data": logs, + }) +} + +func SearchUserLogs(c *gin.Context) { + keyword := c.Query("keyword") + userId := c.GetInt("id") + logs, err := model.SearchUserLogs(userId, keyword) + if err != nil { + c.JSON(200, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(200, gin.H{ + "success": true, + "message": "", + "data": logs, + }) +} + +func GetLogsStat(c *gin.Context) { + logType, _ := strconv.Atoi(c.Query("type")) + startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) + endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) + tokenName := c.Query("token_name") + username := c.Query("username") + modelName := c.Query("model_name") + quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) + //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") + c.JSON(200, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "quota": quotaNum, + //"token": tokenNum, + }, + }) +} + +func GetLogsSelfStat(c *gin.Context) { + username := c.GetString("username") + logType, _ := strconv.Atoi(c.Query("type")) + startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) + endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) + tokenName := c.Query("token_name") + modelName := c.Query("model_name") + quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName) + //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) + c.JSON(200, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "quota": quotaNum, + //"token": tokenNum, + }, + }) +} diff --git a/controller/misc.go b/controller/misc.go new file mode 100644 index 0000000000000000000000000000000000000000..2bcbb41f056e2c37a8f4e3673e6597aec0cb1e1e --- /dev/null +++ b/controller/misc.go @@ -0,0 +1,204 @@ +package controller + +import ( + "encoding/json" + "fmt" + "net/http" + "one-api/common" + "one-api/model" + "strings" + + "github.com/gin-gonic/gin" +) + +func GetStatus(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "version": common.Version, + "start_time": common.StartTime, + "email_verification": common.EmailVerificationEnabled, + "github_oauth": common.GitHubOAuthEnabled, + "github_client_id": common.GitHubClientId, + "system_name": common.SystemName, + "logo": common.Logo, + "footer_html": common.Footer, + "wechat_qrcode": common.WeChatAccountQRCodeImageURL, + "wechat_login": common.WeChatAuthEnabled, + "server_address": common.ServerAddress, + "turnstile_check": common.TurnstileCheckEnabled, + "turnstile_site_key": common.TurnstileSiteKey, + "top_up_link": common.TopUpLink, + "chat_link": common.ChatLink, + "quota_per_unit": common.QuotaPerUnit, + "display_in_currency": common.DisplayInCurrencyEnabled, + }, + }) + return +} + +func GetNotice(c *gin.Context) { + common.OptionMapRWMutex.RLock() + defer common.OptionMapRWMutex.RUnlock() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": common.OptionMap["Notice"], + }) + return +} + +func GetAbout(c *gin.Context) { + common.OptionMapRWMutex.RLock() + defer common.OptionMapRWMutex.RUnlock() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": common.OptionMap["About"], + }) + return +} + +func GetHomePageContent(c *gin.Context) { + common.OptionMapRWMutex.RLock() + defer common.OptionMapRWMutex.RUnlock() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": common.OptionMap["HomePageContent"], + }) + return +} + +func SendEmailVerification(c *gin.Context) { + email := c.Query("email") + if err := common.Validate.Var(email, "required,email"); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + if common.EmailDomainRestrictionEnabled { + allowed := false + for _, domain := range common.EmailDomainWhitelist { + if strings.HasSuffix(email, "@"+domain) { + allowed = true + break + } + } + if !allowed { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中", + }) + return + } + } + if model.IsEmailAlreadyTaken(email) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "邮箱地址已被占用", + }) + return + } + code := common.GenerateVerificationCode(6) + common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) + subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) + content := fmt.Sprintf("

您好,你正在进行%s邮箱验证。

"+ + "

您的验证码为: %s

"+ + "

验证码 %d 分钟内有效,如果不是本人操作,请忽略。

", common.SystemName, code, common.VerificationValidMinutes) + err := common.SendEmail(subject, email, content) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func SendPasswordResetEmail(c *gin.Context) { + email := c.Query("email") + if err := common.Validate.Var(email, "required,email"); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + if !model.IsEmailAlreadyTaken(email) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该邮箱地址未注册", + }) + return + } + code := common.GenerateVerificationCode(0) + common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) + link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) + subject := fmt.Sprintf("%s密码重置", common.SystemName) + content := fmt.Sprintf("

您好,你正在进行%s密码重置。

"+ + "

点击 此处 进行密码重置。

"+ + "

如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:
%s

"+ + "

重置链接 %d 分钟内有效,如果不是本人操作,请忽略。

", common.SystemName, link, link, common.VerificationValidMinutes) + err := common.SendEmail(subject, email, content) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +type PasswordResetRequest struct { + Email string `json:"email"` + Token string `json:"token"` +} + +func ResetPassword(c *gin.Context) { + var req PasswordResetRequest + err := json.NewDecoder(c.Request.Body).Decode(&req) + if req.Email == "" || req.Token == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + if !common.VerifyCodeWithKey(req.Email, req.Token, common.PasswordResetPurpose) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "重置链接非法或已过期", + }) + return + } + password := common.GenerateVerificationCode(12) + err = model.ResetUserPasswordByEmail(req.Email, password) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + common.DeleteKey(req.Email, common.PasswordResetPurpose) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": password, + }) + return +} diff --git a/controller/model.go b/controller/model.go new file mode 100644 index 0000000000000000000000000000000000000000..88f95f7b4b6dc4b5be5dcc8141cd60f374b02888 --- /dev/null +++ b/controller/model.go @@ -0,0 +1,446 @@ +package controller + +import ( + "fmt" + + "github.com/gin-gonic/gin" +) + +// https://platform.openai.com/docs/api-reference/models/list + +type OpenAIModelPermission struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + AllowCreateEngine bool `json:"allow_create_engine"` + AllowSampling bool `json:"allow_sampling"` + AllowLogprobs bool `json:"allow_logprobs"` + AllowSearchIndices bool `json:"allow_search_indices"` + AllowView bool `json:"allow_view"` + AllowFineTuning bool `json:"allow_fine_tuning"` + Organization string `json:"organization"` + Group *string `json:"group"` + IsBlocking bool `json:"is_blocking"` +} + +type OpenAIModels struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + OwnedBy string `json:"owned_by"` + Permission []OpenAIModelPermission `json:"permission"` + Root string `json:"root"` + Parent *string `json:"parent"` +} + +var openAIModels []OpenAIModels +var openAIModelsMap map[string]OpenAIModels + +func init() { + var permission []OpenAIModelPermission + permission = append(permission, OpenAIModelPermission{ + Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ", + Object: "model_permission", + Created: 1626777600, + AllowCreateEngine: true, + AllowSampling: true, + AllowLogprobs: true, + AllowSearchIndices: false, + AllowView: true, + AllowFineTuning: false, + Organization: "*", + Group: nil, + IsBlocking: false, + }) + // https://platform.openai.com/docs/models/model-endpoint-compatibility + openAIModels = []OpenAIModels{ + { + Id: "dall-e", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "dall-e", + Parent: nil, + }, + { + Id: "whisper-1", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "whisper-1", + Parent: nil, + }, + { + Id: "gpt-3.5-turbo", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo", + Parent: nil, + }, + { + Id: "gpt-3.5-turbo-0301", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo-0301", + Parent: nil, + }, + { + Id: "gpt-3.5-turbo-0613", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo-0613", + Parent: nil, + }, + { + Id: "gpt-3.5-turbo-16k", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo-16k", + Parent: nil, + }, + { + Id: "gpt-3.5-turbo-16k-0613", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo-16k-0613", + Parent: nil, + }, + { + Id: "gpt-4", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4", + Parent: nil, + }, + { + Id: "gpt-4-0314", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-0314", + Parent: nil, + }, + { + Id: "gpt-4-0613", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-0613", + Parent: nil, + }, + { + Id: "gpt-4-32k", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-32k", + Parent: nil, + }, + { + Id: "gpt-4-32k-0314", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-32k-0314", + Parent: nil, + }, + { + Id: "gpt-4-32k-0613", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-32k-0613", + Parent: nil, + }, + { + Id: "text-embedding-ada-002", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-embedding-ada-002", + Parent: nil, + }, + { + Id: "text-davinci-003", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-davinci-003", + Parent: nil, + }, + { + Id: "text-davinci-002", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-davinci-002", + Parent: nil, + }, + { + Id: "text-curie-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-curie-001", + Parent: nil, + }, + { + Id: "text-babbage-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-babbage-001", + Parent: nil, + }, + { + Id: "text-ada-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-ada-001", + Parent: nil, + }, + { + Id: "text-moderation-latest", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-moderation-latest", + Parent: nil, + }, + { + Id: "text-moderation-stable", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-moderation-stable", + Parent: nil, + }, + { + Id: "text-davinci-edit-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "text-davinci-edit-001", + Parent: nil, + }, + { + Id: "code-davinci-edit-001", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "code-davinci-edit-001", + Parent: nil, + }, + { + Id: "claude-instant-1", + Object: "model", + Created: 1677649963, + OwnedBy: "anturopic", + Permission: permission, + Root: "claude-instant-1", + Parent: nil, + }, + { + Id: "claude-2", + Object: "model", + Created: 1677649963, + OwnedBy: "anturopic", + Permission: permission, + Root: "claude-2", + Parent: nil, + }, + { + Id: "ERNIE-Bot", + Object: "model", + Created: 1677649963, + OwnedBy: "baidu", + Permission: permission, + Root: "ERNIE-Bot", + Parent: nil, + }, + { + Id: "ERNIE-Bot-turbo", + Object: "model", + Created: 1677649963, + OwnedBy: "baidu", + Permission: permission, + Root: "ERNIE-Bot-turbo", + Parent: nil, + }, + { + Id: "Embedding-V1", + Object: "model", + Created: 1677649963, + OwnedBy: "baidu", + Permission: permission, + Root: "Embedding-V1", + Parent: nil, + }, + { + Id: "PaLM-2", + Object: "model", + Created: 1677649963, + OwnedBy: "google", + Permission: permission, + Root: "PaLM-2", + Parent: nil, + }, + { + Id: "chatglm_pro", + Object: "model", + Created: 1677649963, + OwnedBy: "zhipu", + Permission: permission, + Root: "chatglm_pro", + Parent: nil, + }, + { + Id: "chatglm_std", + Object: "model", + Created: 1677649963, + OwnedBy: "zhipu", + Permission: permission, + Root: "chatglm_std", + Parent: nil, + }, + { + Id: "chatglm_lite", + Object: "model", + Created: 1677649963, + OwnedBy: "zhipu", + Permission: permission, + Root: "chatglm_lite", + Parent: nil, + }, + { + Id: "qwen-v1", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-v1", + Parent: nil, + }, + { + Id: "qwen-plus-v1", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-plus-v1", + Parent: nil, + }, + { + Id: "SparkDesk", + Object: "model", + Created: 1677649963, + OwnedBy: "xunfei", + Permission: permission, + Root: "SparkDesk", + Parent: nil, + }, + { + Id: "360GPT_S2_V9", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "360GPT_S2_V9", + Parent: nil, + }, + { + Id: "embedding-bert-512-v1", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "embedding-bert-512-v1", + Parent: nil, + }, + { + Id: "embedding_s1_v1", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "embedding_s1_v1", + Parent: nil, + }, + { + Id: "semantic_similarity_s1_v1", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "semantic_similarity_s1_v1", + Parent: nil, + }, + { + Id: "360GPT_S2_V9.4", + Object: "model", + Created: 1677649963, + OwnedBy: "360", + Permission: permission, + Root: "360GPT_S2_V9.4", + Parent: nil, + }, + } + openAIModelsMap = make(map[string]OpenAIModels) + for _, model := range openAIModels { + openAIModelsMap[model.Id] = model + } +} + +func ListModels(c *gin.Context) { + c.JSON(200, gin.H{ + "object": "list", + "data": openAIModels, + }) +} + +func RetrieveModel(c *gin.Context) { + modelId := c.Param("model") + if model, ok := openAIModelsMap[modelId]; ok { + c.JSON(200, model) + } else { + openAIError := OpenAIError{ + Message: fmt.Sprintf("The model '%s' does not exist", modelId), + Type: "invalid_request_error", + Param: "model", + Code: "model_not_found", + } + c.JSON(200, gin.H{ + "error": openAIError, + }) + } +} diff --git a/controller/option.go b/controller/option.go new file mode 100644 index 0000000000000000000000000000000000000000..9cf4ff1b24c9769d69904e7d85f790a097fc32ad --- /dev/null +++ b/controller/option.go @@ -0,0 +1,91 @@ +package controller + +import ( + "encoding/json" + "net/http" + "one-api/common" + "one-api/model" + "strings" + + "github.com/gin-gonic/gin" +) + +func GetOptions(c *gin.Context) { + var options []*model.Option + common.OptionMapRWMutex.Lock() + for k, v := range common.OptionMap { + if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { + continue + } + options = append(options, &model.Option{ + Key: k, + Value: common.Interface2String(v), + }) + } + common.OptionMapRWMutex.Unlock() + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": options, + }) + return +} + +func UpdateOption(c *gin.Context) { + var option model.Option + err := json.NewDecoder(c.Request.Body).Decode(&option) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + switch option.Key { + case "GitHubOAuthEnabled": + if option.Value == "true" && common.GitHubClientId == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法启用 GitHub OAuth,请先填入 GitHub Client ID 以及 GitHub Client Secret!", + }) + return + } + case "EmailDomainRestrictionEnabled": + if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", + }) + return + } + case "WeChatAuthEnabled": + if option.Value == "true" && common.WeChatServerAddress == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法启用微信登录,请先填入微信登录相关配置信息!", + }) + return + } + case "TurnstileCheckEnabled": + if option.Value == "true" && common.TurnstileSiteKey == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", + }) + return + } + } + err = model.UpdateOption(option.Key, option.Value) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/controller/redemption.go b/controller/redemption.go new file mode 100644 index 0000000000000000000000000000000000000000..0f656be0e8da07146ea54d5789c5ae58851602b1 --- /dev/null +++ b/controller/redemption.go @@ -0,0 +1,192 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/model" + "strconv" +) + +func GetAllRedemptions(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": redemptions, + }) + return +} + +func SearchRedemptions(c *gin.Context) { + keyword := c.Query("keyword") + redemptions, err := model.SearchRedemptions(keyword) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": redemptions, + }) + return +} + +func GetRedemption(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + redemption, err := model.GetRedemptionById(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": redemption, + }) + return +} + +func AddRedemption(c *gin.Context) { + redemption := model.Redemption{} + err := c.ShouldBindJSON(&redemption) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if len(redemption.Name) == 0 || len(redemption.Name) > 20 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "兑换码名称长度必须在1-20之间", + }) + return + } + if redemption.Count <= 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "兑换码个数必须大于0", + }) + return + } + if redemption.Count > 100 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "一次兑换码批量生成的个数不能大于 100", + }) + return + } + var keys []string + for i := 0; i < redemption.Count; i++ { + key := common.GetUUID() + cleanRedemption := model.Redemption{ + UserId: c.GetInt("id"), + Name: redemption.Name, + Key: key, + CreatedTime: common.GetTimestamp(), + Quota: redemption.Quota, + } + err = cleanRedemption.Insert() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + "data": keys, + }) + return + } + keys = append(keys, key) + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": keys, + }) + return +} + +func DeleteRedemption(c *gin.Context) { + id, _ := strconv.Atoi(c.Param("id")) + err := model.DeleteRedemptionById(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func UpdateRedemption(c *gin.Context) { + statusOnly := c.Query("status_only") + redemption := model.Redemption{} + err := c.ShouldBindJSON(&redemption) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + cleanRedemption, err := model.GetRedemptionById(redemption.Id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if statusOnly != "" { + cleanRedemption.Status = redemption.Status + } else { + // If you add more fields, please also update redemption.Update() + cleanRedemption.Name = redemption.Name + cleanRedemption.Quota = redemption.Quota + } + err = cleanRedemption.Update() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": cleanRedemption, + }) + return +} diff --git a/controller/relay-ali.go b/controller/relay-ali.go new file mode 100644 index 0000000000000000000000000000000000000000..9dca9a898bf4d485ac4960a928404c0cdc260886 --- /dev/null +++ b/controller/relay-ali.go @@ -0,0 +1,241 @@ +package controller + +import ( + "bufio" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "strings" +) + +// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r + +type AliMessage struct { + User string `json:"user"` + Bot string `json:"bot"` +} + +type AliInput struct { + Prompt string `json:"prompt"` + History []AliMessage `json:"history"` +} + +type AliParameters struct { + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` +} + +type AliChatRequest struct { + Model string `json:"model"` + Input AliInput `json:"input"` + Parameters AliParameters `json:"parameters,omitempty"` +} + +type AliError struct { + Code string `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` +} + +type AliUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type AliOutput struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` +} + +type AliChatResponse struct { + Output AliOutput `json:"output"` + Usage AliUsage `json:"usage"` + AliError +} + +func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { + messages := make([]AliMessage, 0, len(request.Messages)) + prompt := "" + for i := 0; i < len(request.Messages); i++ { + message := request.Messages[i] + if message.Role == "system" { + messages = append(messages, AliMessage{ + User: message.Content, + Bot: "Okay", + }) + continue + } else { + if i == len(request.Messages)-1 { + prompt = message.Content + break + } + messages = append(messages, AliMessage{ + User: message.Content, + Bot: request.Messages[i+1].Content, + }) + i++ + } + } + return &AliChatRequest{ + Model: request.Model, + Input: AliInput{ + Prompt: prompt, + History: messages, + }, + //Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's + // TopP: request.TopP, + // TopK: 50, + // //Seed: 0, + // //EnableSearch: false, + //}, + } +} + +func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { + choice := OpenAITextResponseChoice{ + Index: 0, + Message: Message{ + Role: "assistant", + Content: response.Output.Text, + }, + FinishReason: response.Output.FinishReason, + } + fullTextResponse := OpenAITextResponse{ + Id: response.RequestId, + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: []OpenAITextResponseChoice{choice}, + Usage: Usage{ + PromptTokens: response.Usage.InputTokens, + CompletionTokens: response.Usage.OutputTokens, + TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, + }, + } + return &fullTextResponse +} + +func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = aliResponse.Output.Text + if aliResponse.Output.FinishReason != "null" { + finishReason := aliResponse.Output.FinishReason + choice.FinishReason = &finishReason + } + response := ChatCompletionsStreamResponse{ + Id: aliResponse.RequestId, + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "ernie-bot", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var usage Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 { // ignore blank line or wrong format + continue + } + if data[:5] != "data:" { + continue + } + data = data[5:] + dataChan <- data + } + stopChan <- true + }() + setEventStreamHeaders(c) + lastResponseText := "" + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var aliResponse AliChatResponse + err := json.Unmarshal([]byte(data), &aliResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + if aliResponse.Usage.OutputTokens != 0 { + usage.PromptTokens = aliResponse.Usage.InputTokens + usage.CompletionTokens = aliResponse.Usage.OutputTokens + usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens + } + response := streamResponseAli2OpenAI(&aliResponse) + response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) + lastResponseText = aliResponse.Output.Text + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} + +func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var aliResponse AliChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &aliResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if aliResponse.Code != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: aliResponse.Message, + Type: aliResponse.Code, + Param: aliResponse.RequestId, + Code: aliResponse.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseAli2OpenAI(&aliResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/controller/relay-audio.go b/controller/relay-audio.go new file mode 100644 index 0000000000000000000000000000000000000000..277ab4048a2c17de87f27950b1740efbe84ab84c --- /dev/null +++ b/controller/relay-audio.go @@ -0,0 +1,147 @@ +package controller + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { + audioModel := "whisper-1" + + tokenId := c.GetInt("token_id") + channelType := c.GetInt("channel") + userId := c.GetInt("id") + group := c.GetString("group") + + preConsumedTokens := common.PreConsumedQuota + modelRatio := common.GetModelRatio(audioModel) + groupRatio := common.GetGroupRatio(group) + ratio := modelRatio * groupRatio + preConsumedQuota := int(float64(preConsumedTokens) * ratio) + userQuota, err := model.CacheGetUserQuota(userId) + if err != nil { + return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } + } + + // map model name + modelMapping := c.GetString("model_mapping") + if modelMapping != "" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[audioModel] != "" { + audioModel = modelMap[audioModel] + } + } + + baseURL := common.ChannelBaseURLs[channelType] + requestURL := c.Request.URL.String() + + if c.GetString("base_url") != "" { + baseURL = c.GetString("base_url") + } + + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + requestBody := c.Request.Body + + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + + resp, err := httpClient.Do(req) + if err != nil { + return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + err = req.Body.Close() + if err != nil { + return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + err = c.Request.Body.Close() + if err != nil { + return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + var audioResponse AudioResponse + + defer func() { + go func() { + quota := countTokenText(audioResponse.Text, audioModel) + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) + } + }() + }() + + responseBody, err := io.ReadAll(resp.Body) + + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + err = json.Unmarshal(responseBody, &audioResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + return nil +} diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go new file mode 100644 index 0000000000000000000000000000000000000000..39f31a9a90c5767517d57d48ab9610a5d42e9574 --- /dev/null +++ b/controller/relay-baidu.go @@ -0,0 +1,370 @@ +package controller + +import ( + "bufio" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "strings" + "sync" + "time" +) + +// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 + +type BaiduTokenResponse struct { + ExpiresIn int `json:"expires_in"` + AccessToken string `json:"access_token"` +} + +type BaiduMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type BaiduChatRequest struct { + Messages []BaiduMessage `json:"messages"` + Stream bool `json:"stream"` + UserId string `json:"user_id,omitempty"` +} + +type BaiduError struct { + ErrorCode int `json:"error_code"` + ErrorMsg string `json:"error_msg"` +} + +type BaiduChatResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Result string `json:"result"` + IsTruncated bool `json:"is_truncated"` + NeedClearHistory bool `json:"need_clear_history"` + Usage Usage `json:"usage"` + BaiduError +} + +type BaiduChatStreamResponse struct { + BaiduChatResponse + SentenceId int `json:"sentence_id"` + IsEnd bool `json:"is_end"` +} + +type BaiduEmbeddingRequest struct { + Input []string `json:"input"` +} + +type BaiduEmbeddingData struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` +} + +type BaiduEmbeddingResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Data []BaiduEmbeddingData `json:"data"` + Usage Usage `json:"usage"` + BaiduError +} + +type BaiduAccessToken struct { + AccessToken string `json:"access_token"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` + ExpiresIn int64 `json:"expires_in,omitempty"` + ExpiresAt time.Time `json:"-"` +} + +var baiduTokenStore sync.Map + +func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { + messages := make([]BaiduMessage, 0, len(request.Messages)) + for _, message := range request.Messages { + if message.Role == "system" { + messages = append(messages, BaiduMessage{ + Role: "user", + Content: message.Content, + }) + messages = append(messages, BaiduMessage{ + Role: "assistant", + Content: "Okay", + }) + } else { + messages = append(messages, BaiduMessage{ + Role: message.Role, + Content: message.Content, + }) + } + } + return &BaiduChatRequest{ + Messages: messages, + Stream: request.Stream, + } +} + +func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { + choice := OpenAITextResponseChoice{ + Index: 0, + Message: Message{ + Role: "assistant", + Content: response.Result, + }, + FinishReason: "stop", + } + fullTextResponse := OpenAITextResponse{ + Id: response.Id, + Object: "chat.completion", + Created: response.Created, + Choices: []OpenAITextResponseChoice{choice}, + Usage: response.Usage, + } + return &fullTextResponse +} + +func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = baiduResponse.Result + if baiduResponse.IsEnd { + choice.FinishReason = &stopFinishReason + } + response := ChatCompletionsStreamResponse{ + Id: baiduResponse.Id, + Object: "chat.completion.chunk", + Created: baiduResponse.Created, + Model: "ernie-bot", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest { + baiduEmbeddingRequest := BaiduEmbeddingRequest{ + Input: nil, + } + switch request.Input.(type) { + case string: + baiduEmbeddingRequest.Input = []string{request.Input.(string)} + case []any: + for _, item := range request.Input.([]any) { + if str, ok := item.(string); ok { + baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str) + } + } + } + return &baiduEmbeddingRequest +} + +func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse { + openAIEmbeddingResponse := OpenAIEmbeddingResponse{ + Object: "list", + Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)), + Model: "baidu-embedding", + Usage: response.Usage, + } + for _, item := range response.Data { + openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{ + Object: item.Object, + Index: item.Index, + Embedding: item.Embedding, + }) + } + return &openAIEmbeddingResponse +} + +func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var usage Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { // ignore blank line or wrong format + continue + } + data = data[6:] + dataChan <- data + } + stopChan <- true + }() + setEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var baiduResponse BaiduChatStreamResponse + err := json.Unmarshal([]byte(data), &baiduResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + if baiduResponse.Usage.TotalTokens != 0 { + usage.TotalTokens = baiduResponse.Usage.TotalTokens + usage.PromptTokens = baiduResponse.Usage.PromptTokens + usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens + } + response := streamResponseBaidu2OpenAI(&baiduResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} + +func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var baiduResponse BaiduChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &baiduResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if baiduResponse.ErrorMsg != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: baiduResponse.ErrorMsg, + Type: "baidu_error", + Param: "", + Code: baiduResponse.ErrorCode, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseBaidu2OpenAI(&baiduResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var baiduResponse BaiduEmbeddingResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &baiduResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if baiduResponse.ErrorMsg != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: baiduResponse.ErrorMsg, + Type: "baidu_error", + Param: "", + Code: baiduResponse.ErrorCode, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func getBaiduAccessToken(apiKey string) (string, error) { + if val, ok := baiduTokenStore.Load(apiKey); ok { + var accessToken BaiduAccessToken + if accessToken, ok = val.(BaiduAccessToken); ok { + // soon this will expire + if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) { + go func() { + _, _ = getBaiduAccessTokenHelper(apiKey) + }() + } + return accessToken.AccessToken, nil + } + } + accessToken, err := getBaiduAccessTokenHelper(apiKey) + if err != nil { + return "", err + } + if accessToken == nil { + return "", errors.New("getBaiduAccessToken return a nil token") + } + return (*accessToken).AccessToken, nil +} + +func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { + parts := strings.Split(apiKey, "|") + if len(parts) != 2 { + return nil, errors.New("invalid baidu apikey") + } + req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s", + parts[0], parts[1]), nil) + if err != nil { + return nil, err + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Accept", "application/json") + res, err := impatientHTTPClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + var accessToken BaiduAccessToken + err = json.NewDecoder(res.Body).Decode(&accessToken) + if err != nil { + return nil, err + } + if accessToken.Error != "" { + return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription) + } + if accessToken.AccessToken == "" { + return nil, errors.New("getBaiduAccessTokenHelper get empty access token") + } + accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second) + baiduTokenStore.Store(apiKey, accessToken) + return &accessToken, nil +} diff --git a/controller/relay-claude.go b/controller/relay-claude.go new file mode 100644 index 0000000000000000000000000000000000000000..1f4a3e7b117f8d7134463684668744d032edb71a --- /dev/null +++ b/controller/relay-claude.go @@ -0,0 +1,220 @@ +package controller + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "strings" +) + +type ClaudeMetadata struct { + UserId string `json:"user_id"` +} + +type ClaudeRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + MaxTokensToSample int `json:"max_tokens_to_sample"` + StopSequences []string `json:"stop_sequences,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + //ClaudeMetadata `json:"metadata,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type ClaudeError struct { + Type string `json:"type"` + Message string `json:"message"` +} + +type ClaudeResponse struct { + Completion string `json:"completion"` + StopReason string `json:"stop_reason"` + Model string `json:"model"` + Error ClaudeError `json:"error"` +} + +func stopReasonClaude2OpenAI(reason string) string { + switch reason { + case "stop_sequence": + return "stop" + case "max_tokens": + return "length" + default: + return reason + } +} + +func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest { + claudeRequest := ClaudeRequest{ + Model: textRequest.Model, + Prompt: "", + MaxTokensToSample: textRequest.MaxTokens, + StopSequences: nil, + Temperature: textRequest.Temperature, + TopP: textRequest.TopP, + Stream: textRequest.Stream, + } + if claudeRequest.MaxTokensToSample == 0 { + claudeRequest.MaxTokensToSample = 1000000 + } + prompt := "" + for _, message := range textRequest.Messages { + if message.Role == "user" { + prompt += fmt.Sprintf("\n\nHuman: %s", message.Content) + } else if message.Role == "assistant" { + prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) + } else if message.Role == "system" { + prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) + } + } + prompt += "\n\nAssistant:" + claudeRequest.Prompt = prompt + return &claudeRequest +} + +func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = claudeResponse.Completion + finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) + if finishReason != "null" { + choice.FinishReason = &finishReason + } + var response ChatCompletionsStreamResponse + response.Object = "chat.completion.chunk" + response.Model = claudeResponse.Model + response.Choices = []ChatCompletionsStreamResponseChoice{choice} + return &response +} + +func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { + choice := OpenAITextResponseChoice{ + Index: 0, + Message: Message{ + Role: "assistant", + Content: strings.TrimPrefix(claudeResponse.Completion, " "), + Name: nil, + }, + FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), + } + fullTextResponse := OpenAITextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: []OpenAITextResponseChoice{choice}, + } + return &fullTextResponse +} + +func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { + responseText := "" + responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + createdTime := common.GetTimestamp() + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 { + return i + 4, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if !strings.HasPrefix(data, "event: completion") { + continue + } + data = strings.TrimPrefix(data, "event: completion\r\ndata: ") + dataChan <- data + } + stopChan <- true + }() + setEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + var claudeResponse ClaudeResponse + err := json.Unmarshal([]byte(data), &claudeResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + responseText += claudeResponse.Completion + response := streamResponseClaude2OpenAI(&claudeResponse) + response.Id = responseId + response.Created = createdTime + jsonStr, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var claudeResponse ClaudeResponse + err = json.Unmarshal(responseBody, &claudeResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if claudeResponse.Error.Type != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: claudeResponse.Error.Message, + Type: claudeResponse.Error.Type, + Param: "", + Code: claudeResponse.Error.Type, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseClaude2OpenAI(&claudeResponse) + completionTokens := countTokenText(claudeResponse.Completion, model) + usage := Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/controller/relay-image.go b/controller/relay-image.go new file mode 100644 index 0000000000000000000000000000000000000000..de6232884008d8062ae3c761f7d9367b91c0a9ef --- /dev/null +++ b/controller/relay-image.go @@ -0,0 +1,180 @@ +package controller + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" +) + +func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { + imageModel := "dall-e" + + tokenId := c.GetInt("token_id") + channelType := c.GetInt("channel") + userId := c.GetInt("id") + consumeQuota := c.GetBool("consume_quota") + group := c.GetString("group") + + var imageRequest ImageRequest + if consumeQuota { + err := common.UnmarshalBodyReusable(c, &imageRequest) + if err != nil { + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + } + + // Prompt validation + if imageRequest.Prompt == "" { + return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) + } + + // Not "256x256", "512x512", or "1024x1024" + if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { + return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) + } + + // N should between 1 and 10 + if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { + return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) + } + + // map model name + modelMapping := c.GetString("model_mapping") + isModelMapped := false + if modelMapping != "" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[imageModel] != "" { + imageModel = modelMap[imageModel] + isModelMapped = true + } + } + + baseURL := common.ChannelBaseURLs[channelType] + requestURL := c.Request.URL.String() + + if c.GetString("base_url") != "" { + baseURL = c.GetString("base_url") + } + + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + + var requestBody io.Reader + if isModelMapped { + jsonStr, err := json.Marshal(imageRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + } else { + requestBody = c.Request.Body + } + + modelRatio := common.GetModelRatio(imageModel) + groupRatio := common.GetGroupRatio(group) + ratio := modelRatio * groupRatio + userQuota, err := model.CacheGetUserQuota(userId) + + sizeRatio := 1.0 + // Size + if imageRequest.Size == "256x256" { + sizeRatio = 1 + } else if imageRequest.Size == "512x512" { + sizeRatio = 1.125 + } else if imageRequest.Size == "1024x1024" { + sizeRatio = 1.25 + } + quota := int(ratio*sizeRatio*1000) * imageRequest.N + + if consumeQuota && userQuota-quota < 0 { + return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden) + } + + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + + resp, err := httpClient.Do(req) + if err != nil { + return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + err = req.Body.Close() + if err != nil { + return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + err = c.Request.Body.Close() + if err != nil { + return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + var textResponse ImageResponse + + defer func() { + if consumeQuota { + err := model.PostConsumeTokenQuota(tokenId, quota) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) + } + } + }() + + if consumeQuota { + responseBody, err := io.ReadAll(resp.Body) + + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + } + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + return nil +} diff --git a/controller/relay-openai.go b/controller/relay-openai.go new file mode 100644 index 0000000000000000000000000000000000000000..6bdfbc0823a53372e6a1d6337fc997a5d6971f1e --- /dev/null +++ b/controller/relay-openai.go @@ -0,0 +1,144 @@ +package controller + +import ( + "bufio" + "bytes" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "strings" +) + +func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) { + responseText := "" + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { // ignore blank line or wrong format + continue + } + if data[:6] != "data: " && data[:6] != "[DONE]" { + continue + } + dataChan <- data + data = data[6:] + if !strings.HasPrefix(data, "[DONE]") { + switch relayMode { + case RelayModeChatCompletions: + var streamResponse ChatCompletionsStreamResponse + err := json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + continue // just ignore the error + } + for _, choice := range streamResponse.Choices { + responseText += choice.Delta.Content + } + case RelayModeCompletions: + var streamResponse CompletionsStreamResponse + err := json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + continue + } + for _, choice := range streamResponse.Choices { + responseText += choice.Text + } + } + } + } + stopChan <- true + }() + setEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + if strings.HasPrefix(data, "data: [DONE]") { + data = data[:12] + } + // some implementations may add \r at the end of data + data = strings.TrimSuffix(data, "\r") + c.Render(-1, common.CustomEvent{Data: data}) + return true + case <-stopChan: + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { + var textResponse TextResponse + if consumeQuota { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if textResponse.Error.Type != "" { + return &OpenAIErrorWithStatusCode{ + OpenAIError: textResponse.Error, + StatusCode: resp.StatusCode, + }, nil + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + } + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err := io.Copy(c.Writer, resp.Body) + if err != nil { + return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + if textResponse.Usage.TotalTokens == 0 { + completionTokens := 0 + for _, choice := range textResponse.Choices { + completionTokens += countTokenText(choice.Message.Content, model) + } + textResponse.Usage = Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + } + return nil, &textResponse.Usage +} diff --git a/controller/relay-palm.go b/controller/relay-palm.go new file mode 100644 index 0000000000000000000000000000000000000000..a705b318b6e8101dfbe851b116a9fadb0588c1b5 --- /dev/null +++ b/controller/relay-palm.go @@ -0,0 +1,205 @@ +package controller + +import ( + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" +) + +// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body +// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body + +type PaLMChatMessage struct { + Author string `json:"author"` + Content string `json:"content"` +} + +type PaLMFilter struct { + Reason string `json:"reason"` + Message string `json:"message"` +} + +type PaLMPrompt struct { + Messages []PaLMChatMessage `json:"messages"` +} + +type PaLMChatRequest struct { + Prompt PaLMPrompt `json:"prompt"` + Temperature float64 `json:"temperature,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK int `json:"topK,omitempty"` +} + +type PaLMError struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` +} + +type PaLMChatResponse struct { + Candidates []PaLMChatMessage `json:"candidates"` + Messages []Message `json:"messages"` + Filters []PaLMFilter `json:"filters"` + Error PaLMError `json:"error"` +} + +func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { + palmRequest := PaLMChatRequest{ + Prompt: PaLMPrompt{ + Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), + }, + Temperature: textRequest.Temperature, + CandidateCount: textRequest.N, + TopP: textRequest.TopP, + TopK: textRequest.MaxTokens, + } + for _, message := range textRequest.Messages { + palmMessage := PaLMChatMessage{ + Content: message.Content, + } + if message.Role == "user" { + palmMessage.Author = "0" + } else { + palmMessage.Author = "1" + } + palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage) + } + return &palmRequest +} + +func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { + fullTextResponse := OpenAITextResponse{ + Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), + } + for i, candidate := range response.Candidates { + choice := OpenAITextResponseChoice{ + Index: i, + Message: Message{ + Role: "assistant", + Content: candidate.Content, + }, + FinishReason: "stop", + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + return &fullTextResponse +} + +func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + if len(palmResponse.Candidates) > 0 { + choice.Delta.Content = palmResponse.Candidates[0].Content + } + choice.FinishReason = &stopFinishReason + var response ChatCompletionsStreamResponse + response.Object = "chat.completion.chunk" + response.Model = "palm2" + response.Choices = []ChatCompletionsStreamResponseChoice{choice} + return &response +} + +func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { + responseText := "" + responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + createdTime := common.GetTimestamp() + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.SysError("error reading stream response: " + err.Error()) + stopChan <- true + return + } + err = resp.Body.Close() + if err != nil { + common.SysError("error closing stream response: " + err.Error()) + stopChan <- true + return + } + var palmResponse PaLMChatResponse + err = json.Unmarshal(responseBody, &palmResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + stopChan <- true + return + } + fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse) + fullTextResponse.Id = responseId + fullTextResponse.Created = createdTime + if len(palmResponse.Candidates) > 0 { + responseText = palmResponse.Candidates[0].Content + } + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + stopChan <- true + return + } + dataChan <- string(jsonResponse) + stopChan <- true + }() + setEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + c.Render(-1, common.CustomEvent{Data: "data: " + data}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var palmResponse PaLMChatResponse + err = json.Unmarshal(responseBody, &palmResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: palmResponse.Error.Message, + Type: palmResponse.Error.Status, + Param: "", + Code: palmResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responsePaLM2OpenAI(&palmResponse) + completionTokens := countTokenText(palmResponse.Candidates[0].Content, model) + usage := Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/controller/relay-text.go b/controller/relay-text.go new file mode 100644 index 0000000000000000000000000000000000000000..624b9d01c941a2043952dd7f105fd64f9a7cccd0 --- /dev/null +++ b/controller/relay-text.go @@ -0,0 +1,522 @@ +package controller + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/model" + "strings" + "time" +) + +const ( + APITypeOpenAI = iota + APITypeClaude + APITypePaLM + APITypeBaidu + APITypeZhipu + APITypeAli + APITypeXunfei +) + +var httpClient *http.Client +var impatientHTTPClient *http.Client + +func init() { + httpClient = &http.Client{} + impatientHTTPClient = &http.Client{ + Timeout: 5 * time.Second, + } +} + +func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { + channelType := c.GetInt("channel") + tokenId := c.GetInt("token_id") + userId := c.GetInt("id") + consumeQuota := c.GetBool("consume_quota") + group := c.GetString("group") + var textRequest GeneralOpenAIRequest + if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { + err := common.UnmarshalBodyReusable(c, &textRequest) + if err != nil { + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + } + if relayMode == RelayModeModerations && textRequest.Model == "" { + textRequest.Model = "text-moderation-latest" + } + if relayMode == RelayModeEmbeddings && textRequest.Model == "" { + textRequest.Model = c.Param("model") + } + // request validation + if textRequest.Model == "" { + return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) + } + switch relayMode { + case RelayModeCompletions: + if textRequest.Prompt == "" { + return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) + } + case RelayModeChatCompletions: + if textRequest.Messages == nil || len(textRequest.Messages) == 0 { + return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) + } + case RelayModeEmbeddings: + case RelayModeModerations: + if textRequest.Input == "" { + return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) + } + case RelayModeEdits: + if textRequest.Instruction == "" { + return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) + } + } + // map model name + modelMapping := c.GetString("model_mapping") + isModelMapped := false + if modelMapping != "" && modelMapping != "{}" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[textRequest.Model] != "" { + textRequest.Model = modelMap[textRequest.Model] + isModelMapped = true + } + } + apiType := APITypeOpenAI + switch channelType { + case common.ChannelTypeAnthropic: + apiType = APITypeClaude + case common.ChannelTypeBaidu: + apiType = APITypeBaidu + case common.ChannelTypePaLM: + apiType = APITypePaLM + case common.ChannelTypeZhipu: + apiType = APITypeZhipu + case common.ChannelTypeAli: + apiType = APITypeAli + case common.ChannelTypeXunfei: + apiType = APITypeXunfei + } + baseURL := common.ChannelBaseURLs[channelType] + requestURL := c.Request.URL.String() + if c.GetString("base_url") != "" { + baseURL = c.GetString("base_url") + } + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + switch apiType { + case APITypeOpenAI: + if channelType == common.ChannelTypeAzure { + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + requestURL := strings.Split(requestURL, "?")[0] + requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) + baseURL = c.GetString("base_url") + task := strings.TrimPrefix(requestURL, "/v1/") + model_ := textRequest.Model + model_ = strings.Replace(model_, ".", "", -1) + // https://github.com/songquanpeng/one-api/issues/67 + model_ = strings.TrimSuffix(model_, "-0301") + model_ = strings.TrimSuffix(model_, "-0314") + model_ = strings.TrimSuffix(model_, "-0613") + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) + } + case APITypeClaude: + fullRequestURL = "https://api.anthropic.com/v1/complete" + if baseURL != "" { + fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) + } + case APITypeBaidu: + switch textRequest.Model { + case "ERNIE-Bot": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" + case "ERNIE-Bot-turbo": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" + case "BLOOMZ-7B": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" + case "Embedding-V1": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" + } + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + var err error + if apiKey, err = getBaiduAccessToken(apiKey); err != nil { + return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) + } + fullRequestURL += "?access_token=" + apiKey + case APITypePaLM: + fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" + if baseURL != "" { + fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) + } + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + fullRequestURL += "?key=" + apiKey + case APITypeZhipu: + method := "invoke" + if textRequest.Stream { + method = "sse-invoke" + } + fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) + case APITypeAli: + fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" + } + var promptTokens int + var completionTokens int + switch relayMode { + case RelayModeChatCompletions: + promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) + case RelayModeCompletions: + promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) + case RelayModeModerations: + promptTokens = countTokenInput(textRequest.Input, textRequest.Model) + } + preConsumedTokens := common.PreConsumedQuota + if textRequest.MaxTokens != 0 { + preConsumedTokens = promptTokens + textRequest.MaxTokens + } + modelRatio := common.GetModelRatio(textRequest.Model) + groupRatio := common.GetGroupRatio(group) + ratio := modelRatio * groupRatio + preConsumedQuota := int(float64(preConsumedTokens) * ratio) + userQuota, err := model.CacheGetUserQuota(userId) + if err != nil { + return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if consumeQuota && preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } + } + var requestBody io.Reader + if isModelMapped { + jsonStr, err := json.Marshal(textRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + } else { + requestBody = c.Request.Body + } + switch apiType { + case APITypeClaude: + claudeRequest := requestOpenAI2Claude(textRequest) + jsonStr, err := json.Marshal(claudeRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + case APITypeBaidu: + var jsonData []byte + var err error + switch relayMode { + case RelayModeEmbeddings: + baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) + jsonData, err = json.Marshal(baiduEmbeddingRequest) + default: + baiduRequest := requestOpenAI2Baidu(textRequest) + jsonData, err = json.Marshal(baiduRequest) + } + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonData) + case APITypePaLM: + palmRequest := requestOpenAI2PaLM(textRequest) + jsonStr, err := json.Marshal(palmRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + case APITypeZhipu: + zhipuRequest := requestOpenAI2Zhipu(textRequest) + jsonStr, err := json.Marshal(zhipuRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + case APITypeAli: + aliRequest := requestOpenAI2Ali(textRequest) + jsonStr, err := json.Marshal(aliRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + } + + var req *http.Request + var resp *http.Response + isStream := textRequest.Stream + + if apiType != APITypeXunfei { // cause xunfei use websocket + req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + switch apiType { + case APITypeOpenAI: + if channelType == common.ChannelTypeAzure { + req.Header.Set("api-key", apiKey) + } else { + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + if channelType == common.ChannelTypeOpenRouter { + req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") + req.Header.Set("X-Title", "One API") + } + } + case APITypeClaude: + req.Header.Set("x-api-key", apiKey) + anthropicVersion := c.Request.Header.Get("anthropic-version") + if anthropicVersion == "" { + anthropicVersion = "2023-06-01" + } + req.Header.Set("anthropic-version", anthropicVersion) + case APITypeZhipu: + token := getZhipuToken(apiKey) + req.Header.Set("Authorization", token) + case APITypeAli: + req.Header.Set("Authorization", "Bearer "+apiKey) + if textRequest.Stream { + req.Header.Set("X-DashScope-SSE", "enable") + } + } + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + //req.Header.Set("Connection", c.Request.Header.Get("Connection")) + resp, err = httpClient.Do(req) + if err != nil { + return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + err = req.Body.Close() + if err != nil { + return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + err = c.Request.Body.Close() + if err != nil { + return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) + } + isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") + + if resp.StatusCode != http.StatusOK { + return relayErrorHandler(resp) + } + } + + var textResponse TextResponse + tokenName := c.GetString("token_name") + channelId := c.GetInt("channel_id") + + defer func() { + // c.Writer.Flush() + go func() { + if consumeQuota { + quota := 0 + completionRatio := common.GetCompletionRatio(textRequest.Model) + promptTokens = textResponse.Usage.PromptTokens + completionTokens = textResponse.Usage.CompletionTokens + + quota = promptTokens + int(float64(completionTokens)*completionRatio) + quota = int(float64(quota) * ratio) + if ratio != 0 && quota <= 0 { + quota = 1 + } + totalTokens := promptTokens + completionTokens + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + } + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + + model.UpdateChannelUsedQuota(channelId, quota) + } + } + }() + }() + switch apiType { + case APITypeOpenAI: + if isStream { + err, responseText := openaiStreamHandler(c, resp, relayMode) + if err != nil { + return err + } + textResponse.Usage.PromptTokens = promptTokens + textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + return nil + } else { + err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } + case APITypeClaude: + if isStream { + err, responseText := claudeStreamHandler(c, resp) + if err != nil { + return err + } + textResponse.Usage.PromptTokens = promptTokens + textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + return nil + } else { + err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } + case APITypeBaidu: + if isStream { + err, usage := baiduStreamHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } else { + var err *OpenAIErrorWithStatusCode + var usage *Usage + switch relayMode { + case RelayModeEmbeddings: + err, usage = baiduEmbeddingHandler(c, resp) + default: + err, usage = baiduHandler(c, resp) + } + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } + case APITypePaLM: + if textRequest.Stream { // PaLM2 API does not support stream + err, responseText := palmStreamHandler(c, resp) + if err != nil { + return err + } + textResponse.Usage.PromptTokens = promptTokens + textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + return nil + } else { + err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } + case APITypeZhipu: + if isStream { + err, usage := zhipuStreamHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + // zhipu's API does not return prompt tokens & completion tokens + textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens + return nil + } else { + err, usage := zhipuHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + // zhipu's API does not return prompt tokens & completion tokens + textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens + return nil + } + case APITypeAli: + if isStream { + err, usage := aliStreamHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } else { + err, usage := aliHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } + case APITypeXunfei: + if isStream { + auth := c.Request.Header.Get("Authorization") + auth = strings.TrimPrefix(auth, "Bearer ") + splits := strings.Split(auth, "|") + if len(splits) != 3 { + return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) + } + err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } else { + return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest) + } + default: + return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) + } +} diff --git a/controller/relay-utils.go b/controller/relay-utils.go new file mode 100644 index 0000000000000000000000000000000000000000..9010d2757c2a326b7585ff84f41818d559117cbd --- /dev/null +++ b/controller/relay-utils.go @@ -0,0 +1,169 @@ +package controller + +import ( + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/pkoukk/tiktoken-go" + "io" + "net/http" + "one-api/common" + "strconv" +) + +var stopFinishReason = "stop" + +var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} + +func InitTokenEncoders() { + common.SysLog("initializing token encoders") + fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") + if err != nil { + common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error())) + } + for model, _ := range common.ModelRatio { + tokenEncoder, err := tiktoken.EncodingForModel(model) + if err != nil { + common.SysError(fmt.Sprintf("using fallback encoder for model %s", model)) + tokenEncoderMap[model] = fallbackTokenEncoder + continue + } + tokenEncoderMap[model] = tokenEncoder + } + common.SysLog("token encoders initialized") +} + +func getTokenEncoder(model string) *tiktoken.Tiktoken { + if tokenEncoder, ok := tokenEncoderMap[model]; ok { + return tokenEncoder + } + tokenEncoder, err := tiktoken.EncodingForModel(model) + if err != nil { + common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) + tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo") + if err != nil { + common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error())) + } + } + tokenEncoderMap[model] = tokenEncoder + return tokenEncoder +} + +func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { + if common.ApproximateTokenEnabled { + return int(float64(len(text)) * 0.38) + } + return len(tokenEncoder.Encode(text, nil, nil)) +} + +func countTokenMessages(messages []Message, model string) int { + tokenEncoder := getTokenEncoder(model) + // Reference: + // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + // https://github.com/pkoukk/tiktoken-go/issues/6 + // + // Every message follows <|start|>{role/name}\n{content}<|end|>\n + var tokensPerMessage int + var tokensPerName int + if model == "gpt-3.5-turbo-0301" { + tokensPerMessage = 4 + tokensPerName = -1 // If there's a name, the role is omitted + } else { + tokensPerMessage = 3 + tokensPerName = 1 + } + tokenNum := 0 + for _, message := range messages { + tokenNum += tokensPerMessage + tokenNum += getTokenNum(tokenEncoder, message.Content) + tokenNum += getTokenNum(tokenEncoder, message.Role) + if message.Name != nil { + tokenNum += tokensPerName + tokenNum += getTokenNum(tokenEncoder, *message.Name) + } + } + tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> + return tokenNum +} + +func countTokenInput(input any, model string) int { + switch input.(type) { + case string: + return countTokenText(input.(string), model) + case []string: + text := "" + for _, s := range input.([]string) { + text += s + } + return countTokenText(text, model) + } + return 0 +} + +func countTokenText(text string, model string) int { + tokenEncoder := getTokenEncoder(model) + return getTokenNum(tokenEncoder, text) +} + +func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode { + openAIError := OpenAIError{ + Message: err.Error(), + Type: "one_api_error", + Code: code, + } + return &OpenAIErrorWithStatusCode{ + OpenAIError: openAIError, + StatusCode: statusCode, + } +} + +func shouldDisableChannel(err *OpenAIError, statusCode int) bool { + if !common.AutomaticDisableChannelEnabled { + return false + } + if err == nil { + return false + } + if statusCode == http.StatusUnauthorized { + return true + } + if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + return true + } + return false +} + +func setEventStreamHeaders(c *gin.Context) { + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") +} + +func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { + openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ + StatusCode: resp.StatusCode, + OpenAIError: OpenAIError{ + Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), + Type: "one_api_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + err = resp.Body.Close() + if err != nil { + return + } + var textResponse TextResponse + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return + } + openAIErrorWithStatusCode.OpenAIError = textResponse.Error + return +} diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go new file mode 100644 index 0000000000000000000000000000000000000000..3b6fe5a010ec6fa6459be8abbc67082b66d0b443 --- /dev/null +++ b/controller/relay-xunfei.go @@ -0,0 +1,290 @@ +package controller + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "io" + "net/http" + "net/url" + "one-api/common" + "strings" + "time" +) + +// https://console.xfyun.cn/services/cbm +// https://www.xfyun.cn/doc/spark/Web.html + +type XunfeiMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type XunfeiChatRequest struct { + Header struct { + AppId string `json:"app_id"` + } `json:"header"` + Parameter struct { + Chat struct { + Domain string `json:"domain,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Auditing bool `json:"auditing,omitempty"` + } `json:"chat"` + } `json:"parameter"` + Payload struct { + Message struct { + Text []XunfeiMessage `json:"text"` + } `json:"message"` + } `json:"payload"` +} + +type XunfeiChatResponseTextItem struct { + Content string `json:"content"` + Role string `json:"role"` + Index int `json:"index"` +} + +type XunfeiChatResponse struct { + Header struct { + Code int `json:"code"` + Message string `json:"message"` + Sid string `json:"sid"` + Status int `json:"status"` + } `json:"header"` + Payload struct { + Choices struct { + Status int `json:"status"` + Seq int `json:"seq"` + Text []XunfeiChatResponseTextItem `json:"text"` + } `json:"choices"` + Usage struct { + //Text struct { + // QuestionTokens string `json:"question_tokens"` + // PromptTokens string `json:"prompt_tokens"` + // CompletionTokens string `json:"completion_tokens"` + // TotalTokens string `json:"total_tokens"` + //} `json:"text"` + Text Usage `json:"text"` + } `json:"usage"` + } `json:"payload"` +} + +func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest { + messages := make([]XunfeiMessage, 0, len(request.Messages)) + for _, message := range request.Messages { + if message.Role == "system" { + messages = append(messages, XunfeiMessage{ + Role: "user", + Content: message.Content, + }) + messages = append(messages, XunfeiMessage{ + Role: "assistant", + Content: "Okay", + }) + } else { + messages = append(messages, XunfeiMessage{ + Role: message.Role, + Content: message.Content, + }) + } + } + xunfeiRequest := XunfeiChatRequest{} + xunfeiRequest.Header.AppId = xunfeiAppId + xunfeiRequest.Parameter.Chat.Domain = domain + xunfeiRequest.Parameter.Chat.Temperature = request.Temperature + xunfeiRequest.Parameter.Chat.TopK = request.N + xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens + xunfeiRequest.Payload.Message.Text = messages + return &xunfeiRequest +} + +func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { + if len(response.Payload.Choices.Text) == 0 { + response.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + { + Content: "", + }, + } + } + choice := OpenAITextResponseChoice{ + Index: 0, + Message: Message{ + Role: "assistant", + Content: response.Payload.Choices.Text[0].Content, + }, + } + fullTextResponse := OpenAITextResponse{ + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: []OpenAITextResponseChoice{choice}, + Usage: response.Payload.Usage.Text, + } + return &fullTextResponse +} + +func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse { + if len(xunfeiResponse.Payload.Choices.Text) == 0 { + xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + { + Content: "", + }, + } + } + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content + if xunfeiResponse.Payload.Choices.Status == 2 { + choice.FinishReason = &stopFinishReason + } + response := ChatCompletionsStreamResponse{ + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "SparkDesk", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { + HmacWithShaToBase64 := func(algorithm, data, key string) string { + mac := hmac.New(sha256.New, []byte(key)) + mac.Write([]byte(data)) + encodeData := mac.Sum(nil) + return base64.StdEncoding.EncodeToString(encodeData) + } + ul, err := url.Parse(hostUrl) + if err != nil { + fmt.Println(err) + } + date := time.Now().UTC().Format(time.RFC1123) + signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} + sign := strings.Join(signString, "\n") + sha := HmacWithShaToBase64("hmac-sha256", sign, apiSecret) + authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, + "hmac-sha256", "host date request-line", sha) + authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) + v := url.Values{} + v.Add("host", ul.Host) + v.Add("date", date) + v.Add("authorization", authorization) + callUrl := hostUrl + "?" + v.Encode() + return callUrl +} + +func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) { + var usage Usage + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + if apiVersion == "" { + apiVersion = "v1.1" + common.SysLog("api_version not found, use default: " + apiVersion) + } + domain := "general" + if apiVersion == "v2.1" { + domain = "generalv2" + } + hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion) + d := websocket.Dialer{ + HandshakeTimeout: 5 * time.Second, + } + conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil) + if err != nil || resp.StatusCode != 101 { + return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil + } + data := requestOpenAI2Xunfei(textRequest, appId, domain) + err = conn.WriteJSON(data) + if err != nil { + return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil + } + dataChan := make(chan XunfeiChatResponse) + stopChan := make(chan bool) + go func() { + for { + _, msg, err := conn.ReadMessage() + if err != nil { + common.SysError("error reading stream response: " + err.Error()) + break + } + var response XunfeiChatResponse + err = json.Unmarshal(msg, &response) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + break + } + dataChan <- response + if response.Payload.Choices.Status == 2 { + err := conn.Close() + if err != nil { + common.SysError("error closing websocket connection: " + err.Error()) + } + break + } + } + stopChan <- true + }() + setEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case xunfeiResponse := <-dataChan: + usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens + usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens + usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens + response := streamResponseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + return nil, &usage +} + +func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var xunfeiResponse XunfeiChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &xunfeiResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if xunfeiResponse.Header.Code != 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: xunfeiResponse.Header.Message, + Type: "xunfei_error", + Param: "", + Code: xunfeiResponse.Header.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/controller/relay-zhipu.go b/controller/relay-zhipu.go new file mode 100644 index 0000000000000000000000000000000000000000..7a4a582dfe39eaf083c21deba3d26c25ae51c9ba --- /dev/null +++ b/controller/relay-zhipu.go @@ -0,0 +1,301 @@ +package controller + +import ( + "bufio" + "encoding/json" + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt" + "io" + "net/http" + "one-api/common" + "strings" + "sync" + "time" +) + +// https://open.bigmodel.cn/doc/api#chatglm_std +// chatglm_std, chatglm_lite +// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke +// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke + +type ZhipuMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ZhipuRequest struct { + Prompt []ZhipuMessage `json:"prompt"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + RequestId string `json:"request_id,omitempty"` + Incremental bool `json:"incremental,omitempty"` +} + +type ZhipuResponseData struct { + TaskId string `json:"task_id"` + RequestId string `json:"request_id"` + TaskStatus string `json:"task_status"` + Choices []ZhipuMessage `json:"choices"` + Usage `json:"usage"` +} + +type ZhipuResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Success bool `json:"success"` + Data ZhipuResponseData `json:"data"` +} + +type ZhipuStreamMetaResponse struct { + RequestId string `json:"request_id"` + TaskId string `json:"task_id"` + TaskStatus string `json:"task_status"` + Usage `json:"usage"` +} + +type zhipuTokenData struct { + Token string + ExpiryTime time.Time +} + +var zhipuTokens sync.Map +var expSeconds int64 = 24 * 3600 + +func getZhipuToken(apikey string) string { + data, ok := zhipuTokens.Load(apikey) + if ok { + tokenData := data.(zhipuTokenData) + if time.Now().Before(tokenData.ExpiryTime) { + return tokenData.Token + } + } + + split := strings.Split(apikey, ".") + if len(split) != 2 { + common.SysError("invalid zhipu key: " + apikey) + return "" + } + + id := split[0] + secret := split[1] + + expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6 + expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second) + + timestamp := time.Now().UnixNano() / 1e6 + + payload := jwt.MapClaims{ + "api_key": id, + "exp": expMillis, + "timestamp": timestamp, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) + + token.Header["alg"] = "HS256" + token.Header["sign_type"] = "SIGN" + + tokenString, err := token.SignedString([]byte(secret)) + if err != nil { + return "" + } + + zhipuTokens.Store(apikey, zhipuTokenData{ + Token: tokenString, + ExpiryTime: expiryTime, + }) + + return tokenString +} + +func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { + messages := make([]ZhipuMessage, 0, len(request.Messages)) + for _, message := range request.Messages { + if message.Role == "system" { + messages = append(messages, ZhipuMessage{ + Role: "system", + Content: message.Content, + }) + messages = append(messages, ZhipuMessage{ + Role: "user", + Content: "Okay", + }) + } else { + messages = append(messages, ZhipuMessage{ + Role: message.Role, + Content: message.Content, + }) + } + } + return &ZhipuRequest{ + Prompt: messages, + Temperature: request.Temperature, + TopP: request.TopP, + Incremental: false, + } +} + +func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { + fullTextResponse := OpenAITextResponse{ + Id: response.Data.TaskId, + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)), + Usage: response.Data.Usage, + } + for i, choice := range response.Data.Choices { + openaiChoice := OpenAITextResponseChoice{ + Index: i, + Message: Message{ + Role: choice.Role, + Content: strings.Trim(choice.Content, "\""), + }, + FinishReason: "", + } + if i == len(response.Data.Choices)-1 { + openaiChoice.FinishReason = "stop" + } + fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice) + } + return &fullTextResponse +} + +func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = zhipuResponse + response := ChatCompletionsStreamResponse{ + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "chatglm", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = "" + choice.FinishReason = &stopFinishReason + response := ChatCompletionsStreamResponse{ + Id: zhipuResponse.RequestId, + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "chatglm", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } + return &response, &zhipuResponse.Usage +} + +func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var usage *Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { + return i + 2, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + metaChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + lines := strings.Split(data, "\n") + for i, line := range lines { + if len(line) < 5 { + continue + } + if line[:5] == "data:" { + dataChan <- line[5:] + if i != len(lines)-1 { + dataChan <- "\n" + } + } else if line[:5] == "meta:" { + metaChan <- line[5:] + } + } + } + stopChan <- true + }() + setEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + response := streamResponseZhipu2OpenAI(data) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case data := <-metaChan: + var zhipuResponse ZhipuStreamMetaResponse + err := json.Unmarshal([]byte(data), &zhipuResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + usage = zhipuUsage + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, usage +} + +func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var zhipuResponse ZhipuResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &zhipuResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if !zhipuResponse.Success { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: zhipuResponse.Msg, + Type: "zhipu_error", + Param: "", + Code: zhipuResponse.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/controller/relay.go b/controller/relay.go new file mode 100644 index 0000000000000000000000000000000000000000..056d42d38fac1474eb2a62a8ccddf6f1c99a6f16 --- /dev/null +++ b/controller/relay.go @@ -0,0 +1,228 @@ +package controller + +import ( + "fmt" + "net/http" + "one-api/common" + "strconv" + "strings" + + "github.com/gin-gonic/gin" +) + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + Name *string `json:"name,omitempty"` +} + +const ( + RelayModeUnknown = iota + RelayModeChatCompletions + RelayModeCompletions + RelayModeEmbeddings + RelayModeModerations + RelayModeImagesGenerations + RelayModeEdits + RelayModeAudio +) + +// https://platform.openai.com/docs/api-reference/chat + +type GeneralOpenAIRequest struct { + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Prompt any `json:"prompt,omitempty"` + Stream bool `json:"stream,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Input any `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` + Functions any `json:"functions,omitempty"` +} + +type ChatRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + MaxTokens int `json:"max_tokens"` +} + +type TextRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Prompt string `json:"prompt"` + MaxTokens int `json:"max_tokens"` + //Stream bool `json:"stream"` +} + +type ImageRequest struct { + Prompt string `json:"prompt"` + N int `json:"n"` + Size string `json:"size"` +} + +type AudioResponse struct { + Text string `json:"text,omitempty"` +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type OpenAIError struct { + Message string `json:"message"` + Type string `json:"type"` + Param string `json:"param"` + Code any `json:"code"` +} + +type OpenAIErrorWithStatusCode struct { + OpenAIError + StatusCode int `json:"status_code"` +} + +type TextResponse struct { + Choices []OpenAITextResponseChoice `json:"choices"` + Usage `json:"usage"` + Error OpenAIError `json:"error"` +} + +type OpenAITextResponseChoice struct { + Index int `json:"index"` + Message `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type OpenAITextResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Choices []OpenAITextResponseChoice `json:"choices"` + Usage `json:"usage"` +} + +type OpenAIEmbeddingResponseItem struct { + Object string `json:"object"` + Index int `json:"index"` + Embedding []float64 `json:"embedding"` +} + +type OpenAIEmbeddingResponse struct { + Object string `json:"object"` + Data []OpenAIEmbeddingResponseItem `json:"data"` + Model string `json:"model"` + Usage `json:"usage"` +} + +type ImageResponse struct { + Created int `json:"created"` + Data []struct { + Url string `json:"url"` + } +} + +type ChatCompletionsStreamResponseChoice struct { + Delta struct { + Content string `json:"content"` + } `json:"delta"` + FinishReason *string `json:"finish_reason"` +} + +type ChatCompletionsStreamResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionsStreamResponseChoice `json:"choices"` +} + +type CompletionsStreamResponse struct { + Choices []struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` +} + +func Relay(c *gin.Context) { + relayMode := RelayModeUnknown + if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { + relayMode = RelayModeChatCompletions + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { + relayMode = RelayModeCompletions + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { + relayMode = RelayModeEmbeddings + } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + relayMode = RelayModeEmbeddings + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + relayMode = RelayModeModerations + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + relayMode = RelayModeImagesGenerations + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { + relayMode = RelayModeEdits + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + relayMode = RelayModeAudio + } + var err *OpenAIErrorWithStatusCode + switch relayMode { + case RelayModeImagesGenerations: + err = relayImageHelper(c, relayMode) + case RelayModeAudio: + err = relayAudioHelper(c, relayMode) + default: + err = relayTextHelper(c, relayMode) + } + if err != nil { + retryTimesStr := c.Query("retry") + retryTimes, _ := strconv.Atoi(retryTimesStr) + if retryTimesStr == "" { + retryTimes = common.RetryTimes + } + if retryTimes > 0 { + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) + } else { + if err.StatusCode == http.StatusTooManyRequests { + err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" + } + c.JSON(err.StatusCode, gin.H{ + "error": err.OpenAIError, + }) + } + channelId := c.GetInt("channel_id") + common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) + // https://platform.openai.com/docs/guides/error-codes/api-errors + if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { + channelId := c.GetInt("channel_id") + channelName := c.GetString("channel_name") + disableChannel(channelId, channelName, err.Message) + } + } +} + +func RelayNotImplemented(c *gin.Context) { + err := OpenAIError{ + Message: "API not implemented", + Type: "one_api_error", + Param: "", + Code: "api_not_implemented", + } + c.JSON(http.StatusNotImplemented, gin.H{ + "error": err, + }) +} + +func RelayNotFound(c *gin.Context) { + err := OpenAIError{ + Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), + Type: "invalid_request_error", + Param: "", + Code: "", + } + c.JSON(http.StatusNotFound, gin.H{ + "error": err, + }) +} diff --git a/controller/token.go b/controller/token.go new file mode 100644 index 0000000000000000000000000000000000000000..8642122ca38e9a4bc52dc1fc46ac532afc9be5ef --- /dev/null +++ b/controller/token.go @@ -0,0 +1,228 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/model" + "strconv" +) + +func GetAllTokens(c *gin.Context) { + userId := c.GetInt("id") + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": tokens, + }) + return +} + +func SearchTokens(c *gin.Context) { + userId := c.GetInt("id") + keyword := c.Query("keyword") + tokens, err := model.SearchUserTokens(userId, keyword) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": tokens, + }) + return +} + +func GetToken(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + userId := c.GetInt("id") + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + token, err := model.GetTokenByIds(id, userId) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": token, + }) + return +} + +func GetTokenStatus(c *gin.Context) { + tokenId := c.GetInt("token_id") + userId := c.GetInt("id") + token, err := model.GetTokenByIds(tokenId, userId) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + expiredAt := token.ExpiredTime + if expiredAt == -1 { + expiredAt = 0 + } + c.JSON(http.StatusOK, gin.H{ + "object": "credit_summary", + "total_granted": token.RemainQuota, + "total_used": 0, // not supported currently + "total_available": token.RemainQuota, + "expires_at": expiredAt * 1000, + }) +} + +func AddToken(c *gin.Context) { + token := model.Token{} + err := c.ShouldBindJSON(&token) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if len(token.Name) > 30 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "令牌名称过长", + }) + return + } + cleanToken := model.Token{ + UserId: c.GetInt("id"), + Name: token.Name, + Key: common.GenerateKey(), + CreatedTime: common.GetTimestamp(), + AccessedTime: common.GetTimestamp(), + ExpiredTime: token.ExpiredTime, + RemainQuota: token.RemainQuota, + UnlimitedQuota: token.UnlimitedQuota, + } + err = cleanToken.Insert() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func DeleteToken(c *gin.Context) { + id, _ := strconv.Atoi(c.Param("id")) + userId := c.GetInt("id") + err := model.DeleteTokenById(id, userId) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func UpdateToken(c *gin.Context) { + userId := c.GetInt("id") + statusOnly := c.Query("status_only") + token := model.Token{} + err := c.ShouldBindJSON(&token) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if len(token.Name) > 30 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "令牌名称过长", + }) + return + } + cleanToken, err := model.GetTokenByIds(token.Id, userId) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if token.Status == common.TokenStatusEnabled { + if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", + }) + return + } + if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", + }) + return + } + } + if statusOnly != "" { + cleanToken.Status = token.Status + } else { + // If you add more fields, please also update token.Update() + cleanToken.Name = token.Name + cleanToken.ExpiredTime = token.ExpiredTime + cleanToken.RemainQuota = token.RemainQuota + cleanToken.UnlimitedQuota = token.UnlimitedQuota + } + err = cleanToken.Update() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": cleanToken, + }) + return +} diff --git a/controller/user.go b/controller/user.go new file mode 100644 index 0000000000000000000000000000000000000000..8fd10b82774aaf07e45aaa6b6af84469e7d0a8fe --- /dev/null +++ b/controller/user.go @@ -0,0 +1,743 @@ +package controller + +import ( + "encoding/json" + "fmt" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" +) + +type LoginRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +func Login(c *gin.Context) { + if !common.PasswordLoginEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "管理员关闭了密码登录", + "success": false, + }) + return + } + var loginRequest LoginRequest + err := json.NewDecoder(c.Request.Body).Decode(&loginRequest) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": "无效的参数", + "success": false, + }) + return + } + username := loginRequest.Username + password := loginRequest.Password + if username == "" || password == "" { + c.JSON(http.StatusOK, gin.H{ + "message": "无效的参数", + "success": false, + }) + return + } + user := model.User{ + Username: username, + Password: password, + } + err = user.ValidateAndFill() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": err.Error(), + "success": false, + }) + return + } + setupLogin(&user, c) +} + +// setup session & cookies and then return user info +func setupLogin(user *model.User, c *gin.Context) { + session := sessions.Default(c) + session.Set("id", user.Id) + session.Set("username", user.Username) + session.Set("role", user.Role) + session.Set("status", user.Status) + err := session.Save() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": "无法保存会话信息,请重试", + "success": false, + }) + return + } + cleanUser := model.User{ + Id: user.Id, + Username: user.Username, + DisplayName: user.DisplayName, + Role: user.Role, + Status: user.Status, + } + c.JSON(http.StatusOK, gin.H{ + "message": "", + "success": true, + "data": cleanUser, + }) +} + +func Logout(c *gin.Context) { + session := sessions.Default(c) + session.Clear() + err := session.Save() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": err.Error(), + "success": false, + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "message": "", + "success": true, + }) +} + +func Register(c *gin.Context) { + if !common.RegisterEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "管理员关闭了新用户注册", + "success": false, + }) + return + } + if !common.PasswordRegisterEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", + "success": false, + }) + return + } + var user model.User + err := json.NewDecoder(c.Request.Body).Decode(&user) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + if err := common.Validate.Struct(&user); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "输入不合法 " + err.Error(), + }) + return + } + if common.EmailVerificationEnabled { + if user.Email == "" || user.VerificationCode == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员开启了邮箱验证,请输入邮箱地址和验证码", + }) + return + } + if !common.VerifyCodeWithKey(user.Email, user.VerificationCode, common.EmailVerificationPurpose) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "验证码错误或已过期", + }) + return + } + } + affCode := user.AffCode // this code is the inviter's code, not the user's own code + inviterId, _ := model.GetUserIdByAffCode(affCode) + cleanUser := model.User{ + Username: user.Username, + Password: user.Password, + DisplayName: user.Username, + InviterId: inviterId, + } + if common.EmailVerificationEnabled { + cleanUser.Email = user.Email + } + if err := cleanUser.Insert(inviterId); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func GetAllUsers(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": users, + }) + return +} + +func SearchUsers(c *gin.Context) { + keyword := c.Query("keyword") + users, err := model.SearchUsers(keyword) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": users, + }) + return +} + +func GetUser(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user, err := model.GetUserById(id, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + myRole := c.GetInt("role") + if myRole <= user.Role && myRole != common.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权获取同级或更高等级用户的信息", + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": user, + }) + return +} + +func GenerateAccessToken(c *gin.Context) { + id := c.GetInt("id") + user, err := model.GetUserById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.AccessToken = common.GetUUID() + + if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "请重试,系统生成的 UUID 竟然重复了!", + }) + return + } + + if err := user.Update(false); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": user.AccessToken, + }) + return +} + +func GetAffCode(c *gin.Context) { + id := c.GetInt("id") + user, err := model.GetUserById(id, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if user.AffCode == "" { + user.AffCode = common.GetRandomString(4) + if err := user.Update(false); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": user.AffCode, + }) + return +} + +func GetSelf(c *gin.Context) { + id := c.GetInt("id") + user, err := model.GetUserById(id, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": user, + }) + return +} + +func UpdateUser(c *gin.Context) { + var updatedUser model.User + err := json.NewDecoder(c.Request.Body).Decode(&updatedUser) + if err != nil || updatedUser.Id == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + if updatedUser.Password == "" { + updatedUser.Password = "$I_LOVE_U" // make Validator happy :) + } + if err := common.Validate.Struct(&updatedUser); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "输入不合法 " + err.Error(), + }) + return + } + originUser, err := model.GetUserById(updatedUser.Id, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + myRole := c.GetInt("role") + if myRole <= originUser.Role && myRole != common.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权更新同权限等级或更高权限等级的用户信息", + }) + return + } + if myRole <= updatedUser.Role && myRole != common.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权将其他用户权限等级提升到大于等于自己的权限等级", + }) + return + } + if updatedUser.Password == "$I_LOVE_U" { + updatedUser.Password = "" // rollback to what it should be + } + updatePassword := updatedUser.Password != "" + if err := updatedUser.Update(updatePassword); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if originUser.Quota != updatedUser.Quota { + model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func UpdateSelf(c *gin.Context) { + var user model.User + err := json.NewDecoder(c.Request.Body).Decode(&user) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + if user.Password == "" { + user.Password = "$I_LOVE_U" // make Validator happy :) + } + if err := common.Validate.Struct(&user); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "输入不合法 " + err.Error(), + }) + return + } + + cleanUser := model.User{ + Id: c.GetInt("id"), + Username: user.Username, + Password: user.Password, + DisplayName: user.DisplayName, + } + if user.Password == "$I_LOVE_U" { + user.Password = "" // rollback to what it should be + cleanUser.Password = "" + } + updatePassword := user.Password != "" + if err := cleanUser.Update(updatePassword); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func DeleteUser(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + originUser, err := model.GetUserById(id, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + myRole := c.GetInt("role") + if myRole <= originUser.Role { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权删除同权限等级或更高权限等级的用户", + }) + return + } + err = model.DeleteUserById(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return + } +} + +func DeleteSelf(c *gin.Context) { + id := c.GetInt("id") + user, _ := model.GetUserById(id, false) + + if user.Role == common.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "不能删除超级管理员账户", + }) + return + } + + err := model.DeleteUserById(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +func CreateUser(c *gin.Context) { + var user model.User + err := json.NewDecoder(c.Request.Body).Decode(&user) + if err != nil || user.Username == "" || user.Password == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + if err := common.Validate.Struct(&user); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "输入不合法 " + err.Error(), + }) + return + } + if user.DisplayName == "" { + user.DisplayName = user.Username + } + myRole := c.GetInt("role") + if user.Role >= myRole { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法创建权限大于等于自己的用户", + }) + return + } + // Even for admin users, we cannot fully trust them! + cleanUser := model.User{ + Username: user.Username, + Password: user.Password, + DisplayName: user.DisplayName, + } + if err := cleanUser.Insert(0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +type ManageRequest struct { + Username string `json:"username"` + Action string `json:"action"` +} + +// ManageUser Only admin user can do this +func ManageUser(c *gin.Context) { + var req ManageRequest + err := json.NewDecoder(c.Request.Body).Decode(&req) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + user := model.User{ + Username: req.Username, + } + // Fill attributes + model.DB.Where(&user).First(&user) + if user.Id == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户不存在", + }) + return + } + myRole := c.GetInt("role") + if myRole <= user.Role && myRole != common.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权更新同权限等级或更高权限等级的用户信息", + }) + return + } + switch req.Action { + case "disable": + user.Status = common.UserStatusDisabled + if user.Role == common.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法禁用超级管理员用户", + }) + return + } + case "enable": + user.Status = common.UserStatusEnabled + case "delete": + if user.Role == common.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法删除超级管理员用户", + }) + return + } + if err := user.Delete(); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + case "promote": + if myRole != common.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "普通管理员用户无法提升其他用户为管理员", + }) + return + } + if user.Role >= common.RoleAdminUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该用户已经是管理员", + }) + return + } + user.Role = common.RoleAdminUser + case "demote": + if user.Role == common.RoleRootUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法降级超级管理员用户", + }) + return + } + if user.Role == common.RoleCommonUser { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该用户已经是普通用户", + }) + return + } + user.Role = common.RoleCommonUser + } + + if err := user.Update(false); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + clearUser := model.User{ + Role: user.Role, + Status: user.Status, + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": clearUser, + }) + return +} + +func EmailBind(c *gin.Context) { + email := c.Query("email") + code := c.Query("code") + if !common.VerifyCodeWithKey(email, code, common.EmailVerificationPurpose) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "验证码错误或已过期", + }) + return + } + id := c.GetInt("id") + user := model.User{ + Id: id, + } + err := user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.Email = email + // no need to check if this email already taken, because we have used verification code to check it + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if user.Role == common.RoleRootUser { + common.RootUserEmail = email + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + +type topUpRequest struct { + Key string `json:"key"` +} + +func TopUp(c *gin.Context) { + req := topUpRequest{} + err := c.ShouldBindJSON(&req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + id := c.GetInt("id") + quota, err := model.Redeem(req.Key, id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": quota, + }) + return +} diff --git a/controller/wechat.go b/controller/wechat.go new file mode 100644 index 0000000000000000000000000000000000000000..ff4c9fb6c42cc1730ef21c2af0570c5ab05ea73a --- /dev/null +++ b/controller/wechat.go @@ -0,0 +1,164 @@ +package controller + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + "time" +) + +type wechatLoginResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + Data string `json:"data"` +} + +func getWeChatIdByCode(code string) (string, error) { + if code == "" { + return "", errors.New("无效的参数") + } + req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil) + if err != nil { + return "", err + } + req.Header.Set("Authorization", common.WeChatServerToken) + client := http.Client{ + Timeout: 5 * time.Second, + } + httpResponse, err := client.Do(req) + if err != nil { + return "", err + } + defer httpResponse.Body.Close() + var res wechatLoginResponse + err = json.NewDecoder(httpResponse.Body).Decode(&res) + if err != nil { + return "", err + } + if !res.Success { + return "", errors.New(res.Message) + } + if res.Data == "" { + return "", errors.New("验证码错误或已过期") + } + return res.Data, nil +} + +func WeChatAuth(c *gin.Context) { + if !common.WeChatAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "管理员未开启通过微信登录以及注册", + "success": false, + }) + return + } + code := c.Query("code") + wechatId, err := getWeChatIdByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": err.Error(), + "success": false, + }) + return + } + user := model.User{ + WeChatId: wechatId, + } + if model.IsWeChatIdAlreadyTaken(wechatId) { + err := user.FillUserByWeChatId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if common.RegisterEnabled { + user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) + user.DisplayName = "WeChat User" + user.Role = common.RoleCommonUser + user.Status = common.UserStatusEnabled + + if err := user.Insert(0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != common.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + setupLogin(&user, c) +} + +func WeChatBind(c *gin.Context) { + if !common.WeChatAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "管理员未开启通过微信登录以及注册", + "success": false, + }) + return + } + code := c.Query("code") + wechatId, err := getWeChatIdByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": err.Error(), + "success": false, + }) + return + } + if model.IsWeChatIdAlreadyTaken(wechatId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该微信账号已被绑定", + }) + return + } + id := c.GetInt("id") + user := model.User{ + Id: id, + } + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.WeChatId = wechatId + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..003122bbe19af71a794519eae62528acdfc1394f --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,34 @@ +version: '3.4' + +services: + one-api: + image: justsong/one-api:latest + container_name: one-api + restart: always + command: --log-dir /app/logs + ports: + - "3000:3000" + volumes: + - ./data:/data + - ./logs:/app/logs + environment: + - SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库 + - REDIS_CONN_STRING=redis://redis + - SESSION_SECRET=random_string # 修改为随机字符串 + - TZ=Asia/Shanghai +# - NODE_TYPE=slave # 多机部署时从节点取消注释该行 +# - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行 +# - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行 + + depends_on: + - redis + healthcheck: + test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ] + interval: 30s + timeout: 10s + retries: 3 + + redis: + image: redis:latest + container_name: redis + restart: always diff --git a/go.mod b/go.mod new file mode 100644 index 0000000000000000000000000000000000000000..79b01f93d9a99c844b4bfec5a4a01e59206eadec --- /dev/null +++ b/go.mod @@ -0,0 +1,61 @@ +module one-api + +// +heroku goVersion go1.18 +go 1.18 + +require ( + github.com/gin-contrib/cors v1.4.0 + github.com/gin-contrib/gzip v0.0.6 + github.com/gin-contrib/sessions v0.0.5 + github.com/gin-contrib/static v0.0.1 + github.com/gin-gonic/gin v1.9.1 + github.com/go-playground/validator/v10 v10.14.0 + github.com/go-redis/redis/v8 v8.11.5 + github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/google/uuid v1.3.0 + github.com/gorilla/websocket v1.5.0 + github.com/pkoukk/tiktoken-go v0.1.5 + golang.org/x/crypto v0.9.0 + gorm.io/driver/mysql v1.4.3 + gorm.io/driver/sqlite v1.4.3 + gorm.io/gorm v1.25.0 +) + +require ( + github.com/bytedance/sonic v1.9.1 // indirect + github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-sql-driver/mysql v1.6.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/gorilla/context v1.1.1 // indirect + github.com/gorilla/securecookie v1.1.1 // indirect + github.com/gorilla/sessions v1.2.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.3.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect + golang.org/x/arch v0.3.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sys v0.8.0 // indirect + golang.org/x/text v0.9.0 // indirect + google.golang.org/protobuf v1.30.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + gorm.io/driver/postgres v1.5.2 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000000000000000000000000000000000000..810e7819c948bc360c4b568d9e018ef59ba4171e --- /dev/null +++ b/go.sum @@ -0,0 +1,205 @@ +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= +github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= +github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g= +github.com/gin-contrib/cors v1.4.0/go.mod h1:bs9pNM0x/UsmHPBWT2xZz9ROh8xYjYkiURUfmBoMlcs= +github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4= +github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk= +github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfmF+0jE= +github.com/gin-contrib/sessions v0.0.5/go.mod h1:vYAuaUPqie3WUSsft6HUlCjlwwoJQs97miaG2+7neKY= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-contrib/static v0.0.1 h1:JVxuvHPuUfkoul12N7dtQw7KRn/pSMq7Ue1Va9Swm1U= +github.com/gin-contrib/static v0.0.1/go.mod h1:CSxeF+wep05e0kCOsqWdAWbSszmc31zTIbD8TvWl7Hs= +github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= +github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= +github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= +github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= +github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= +github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= +github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= +github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= +github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= +github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= +github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= +github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= +github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= +github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= +github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4= +github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= +github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= +github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= +golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k= +gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= +gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= +gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= +gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= +gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= +gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= +gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74= +gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= +gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= +gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/i18n/en.json b/i18n/en.json new file mode 100644 index 0000000000000000000000000000000000000000..aed659792f364a1efe543698456dda2fcffbb3eb --- /dev/null +++ b/i18n/en.json @@ -0,0 +1,527 @@ +{ + "$%.6f 额度": "$%.6f quota", + "%d 点额度": "%d point quota", + "尚未实现": "Not yet implemented", + "余额不足": "Insufficient balance", + "危险操作": "Hazardous operations", + "输入你的账户名": "Enter your account name", + "确认删除": "Confirm Delete", + "确认绑定": "Confirm Binding", + "您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.", + "\"通道「%s」(#%d)已被禁用\"": "\"Channel %s (#%d) has been disabled\"", + "通道「%s」(#%d)已被禁用,原因:%s": "Channel %s (#%d) has been disabled, reason: %s", + "测试已在运行中": "Test is already running", + "响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs", + "通道测试完成": "Channel test completed", + "通道测试完成,如果没有收到禁用通知,说明所有通道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal", + "无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!", + "返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!", + "管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub", + "管理员关闭了新用户注册": "The administrator has turned off new user registration", + "用户已被封禁": "User has been banned", + "该 GitHub 账户已被绑定": "The GitHub account has been bound", + "邮箱地址已被占用": "Email address is occupied", + "%s邮箱验证邮件": "%s Email verification email", + "

您好,你正在进行%s邮箱验证。

": "

Hello, you are verifying %s email.

", + "

您的验证码为: %s

": "

Your verification code is: %s

", + "

验证码 %d 分钟内有效,如果不是本人操作,请忽略。

": "

The verification code is valid within %d minutes. If it is not your operation, please ignore it.

", + "无效的参数": "Invalid parameter", + "该邮箱地址未注册": "The email address is not registered", + "%s密码重置": "%s Password reset", + "

您好,你正在进行%s密码重置。

": "

Hello, you are resetting %s password.

", + "

点击此处进行密码重置。

": "

Click here to reset your password.

", + "

重置链接 %d 分钟内有效,如果不是本人操作,请忽略。

": "

The reset link is valid within %d minutes. If it is not your operation, please ignore it.

", + "重置链接非法或已过期": "Reset link is illegal or expired", + "无法启用 GitHub OAuth,请先填入 GitHub Client ID 以及 GitHub Client Secret!": "Unable to enable GitHub OAuth, please fill in GitHub Client ID and GitHub Client Secret first!", + "无法启用微信登录,请先填入微信登录相关配置信息!": "Unable to enable WeChat login, please fill in the relevant configuration information for WeChat login first!", + "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!": "Unable to enable Turnstile verification, please fill in the relevant configuration information for Turnstile verification first!", + "兑换码名称长度必须在1-20之间": "The length of the redemption code name must be between 1-20", + "兑换码个数必须大于0": "The number of redemption codes must be greater than 0", + "一次兑换码批量生成的个数不能大于 100": "The number of redemption codes generated in a batch cannot be greater than 100", + "通过令牌「%s」使用模型 %s 消耗 %s(模型倍率 %.2f,分组倍率 %.2f)": "Using model %s with token %s consumes %s (model rate %.2f, group rate %.2f)", + "当前分组上游负载已饱和,请稍后再试": "The current group load is saturated, please try again later", + "令牌名称过长": "Token name is too long", + "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期": "The token has expired and cannot be enabled. Please modify the expiration time of the token, or set it to never expire.", + "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度": "The available quota of the token has been used up and cannot be enabled. Please modify the remaining quota of the token, or set it to unlimited quota", + "管理员关闭了密码登录": "The administrator has turned off password login", + "无法保存会话信息,请重试": "Unable to save session information, please try again", + "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册": "The administrator has turned off registration via password. Please use the form of third-party account verification to register", + "输入不合法 ": "Input is illegal ", + "管理员开启了邮箱验证,请输入邮箱地址和验证码": "The administrator has turned on email verification, please enter the email address and verification code", + "验证码错误或已过期": "Verification code error or expired", + "无权获取同级或更高等级用户的信息": "No permission to get information of users at the same level or higher", + "请重试,系统生成的 UUID 竟然重复了!": "Please try again, the system-generated UUID is actually duplicated!", + "输入不合法": "Input is illegal", + "无权更新同权限等级或更高权限等级的用户信息": "No permission to update user information with the same permission level or higher permission level", + "管理员将用户额度从 %s修改为 %s": "The administrator changed the user quota from %s to %s", + "无权删除同权限等级或更高权限等级的用户": "No permission to delete users with the same permission level or higher permission level", + "无法创建权限大于等于自己的用户": "Unable to create users with permissions greater than or equal to your own", + "用户不存在": "User does not exist", + "无法禁用超级管理员用户": "Unable to disable super administrator user", + "无法删除超级管理员用户": "Unable to delete super administrator user", + "普通管理员用户无法提升其他用户为管理员": "Ordinary administrator users cannot promote other users to administrators", + "该用户已经是管理员": "The user is already an administrator", + "无法降级超级管理员用户": "Unable to downgrade super administrator user", + "该用户已经是普通用户": "The user is already an ordinary user", + "管理员未开启通过微信登录以及注册": "The administrator has not enabled login and registration via WeChat", + "该微信账号已被绑定": "The WeChat account has been bound", + "无权进行此操作,未登录且未提供 access token": "No permission to perform this operation, not logged in and no access token provided", + "无权进行此操作,access token 无效": "No permission to perform this operation, access token is invalid", + "无权进行此操作,权限不足": "No permission to perform this operation, insufficient permissions", + "普通用户不支持指定渠道": "Ordinary users do not support specifying channels", + "无效的渠道 ID": "Invalid channel ID", + "该渠道已被禁用": "The channel has been disabled", + "无效的请求": "Invalid request", + "无可用渠道": "No available channels", + "Turnstile token 为空": "Turnstile token is empty", + "Turnstile 校验失败,请刷新重试!": "Turnstile verification failed, please refresh and try again!", + "id 为空!": "id is empty!", + "未提供兑换码": "No redemption code provided", + "无效的 user id": "Invalid user id", + "无效的兑换码": "Invalid redemption code", + "该兑换码已被使用": "The redemption code has been used", + "通过兑换码充值 %s": "Recharge %s through redemption code", + "未提供令牌": "No token provided", + "该令牌状态不可用": "The token status is not available", + "该令牌已过期": "The token has expired", + "该令牌额度已用尽": "The token quota has been used up", + "无效的令牌": "Invalid token", + "id 或 userId 为空!": "id or userId is empty!", + "quota 不能为负数!": "quota cannot be negative!", + "令牌额度不足": "Insufficient token quota", + "用户额度不足": "Insufficient user quota", + "您的额度即将用尽": "Your quota is about to run out", + "您的额度已用尽": "Your quota has been used up", + "%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s": "%s, the current remaining quota is %d, in order not to affect your use, please recharge in time.
Recharge link: %s", + "affCode 为空!": "affCode is empty!", + "新用户注册赠送 %s": "New user registration gives %s", + "使用邀请码赠送 %s": "Use invitation code to give %s", + "邀请用户赠送 %s": "Invite users to give %s", + "用户名或密码为空": "Username or password is empty", + "用户名或密码错误,或用户已被封禁": "Username or password is wrong, or user has been banned", + "email 为空!": "email is empty!", + "GitHub id 为空!": "GitHub id is empty!", + "WeChat id 为空!": "WeChat id is empty!", + "username 为空!": "username is empty!", + "邮箱地址或密码为空!": "Email address or password is empty!", + "OpenAI 接口聚合管理,支持多种渠道包括 Azure,可用于二次分发管理 key,仅单可执行文件,已打包好 Docker 镜像,一键部署,开箱即用": "OpenAI interface aggregation management, supports multiple channels including Azure, can be used for secondary distribution management key, only single executable file, Docker image has been packaged, one-click deployment, out of the box", + "未知类型": "Unknown type", + "不支持": "Not supported", + "操作成功完成!": "Operation completed successfully!", + "已启用": "Enabled", + "已禁用": "Disabled", + "未知状态": "Unknown status", + " 秒": "s", + " 分钟 ": " m ", + " 小时 ": " h ", + " 天 ": " d ", + " 个月 ": " M ", + " 年 ": " y ", + "未测试": "Not tested", + "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", + "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", + "通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", + "已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!", + "搜索渠道的 ID,名称和密钥 ...": "Search for channel ID, name and key ...", + "名称": "Name", + "分组": "Group", + "类型": "Type", + "状态": "Status", + "响应时间": "Response time", + "余额": "Balance", + "操作": "Operation", + "未更新": "Not updated", + "测试": "Test", + "更新余额": "Update balance", + "删除": "Delete", + "删除渠道 {channel.name}": "Delete channel {channel.name}", + "禁用": "Disable", + "启用": "Enable", + "编辑": "Edit", + "添加新的渠道": "Add a new channel", + "测试所有已启用通道": "Test all enabled channels", + "更新所有已启用通道余额": "Update the balance of all enabled channels", + "刷新": "Refresh", + "处理中...": "Processing...", + "绑定成功!": "Binding succeeded!", + "登录成功!": "Login succeeded!", + "操作失败,重定向至登录界面中...": "Operation failed, redirecting to the login page...", + "出现错误,第 ${count} 次重试中...": "An error occurred, retrying for the ${count} time...", + "首页": "Home", + "渠道": "Channel", + "令牌": "Token", + "兑换": "Redeem", + "充值": "Recharge", + "用户": "User", + "日志": "Log", + "设置": "Settings", + "关于": "About", + "聊天": "Chat", + "注销成功!": "Logout succeeded!", + "注销": "Logout", + "登录": "Login", + "注册": "Register", + "加载{name}中...": "Loading {name}...", + "未登录或登录已过期,请重新登录!": "Not logged in or login has expired, please log in again!", + "用户登录": "User login", + "\"用户名\"": "\"Username\"", + "\"密码\"": "\"Password\"", + "忘记密码?": "Forget password?", + "点击重置": "Click to reset", + "; 没有账户?": "; No account?", + "点击注册": "Click to register", + "微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)": "Scan the QR code of WeChat to follow the official account, enter \"verification code\" to get the verification code (valid within three minutes)", + "\"验证码\"": "\"Verification code\"", + "全部用户": "All users", + "当前用户": "Current user", + "'全部'": "'All'", + "'充值'": "'Recharge'", + "'消费'": "'Consumption'", + "'管理'": "'Management'", + "'系统'": "'System'", + " 充值 ": " Recharge ", + " 消费 ": " Consumption ", + " 管理 ": " Management ", + " 系统 ": " System ", + " 未知 ": " Unknown ", + "时间": "Time", + "详情": "Details", + "选择模式": "Select mode", + "选择明细分类": "Select details category", + "模型倍率不是合法的 JSON 字符串": "Model rate is not a valid JSON string", + "分组倍率不是合法的 JSON 字符串": "Group rate is not a valid JSON string", + "通用设置": "General Settings", + "充值链接": "Recharge Link", + "例如发卡网站的购买链接": "For example, the purchase link of the card issuing website", + "聊天页面链接": "Chat Page Link", + "例如 ChatGPT Next Web 的部署地址": "For example, the deployment address of ChatGPT Next Web", + "单位美元额度": "Unit Dollar Quota", + "一单位货币能兑换的额度": "Quota that can be exchanged for one unit of currency", + "启用额度消费日志记录": "Enable quota consumption log recording", + "以货币形式显示额度": "Display quota in the form of currency", + "相关 API 显示令牌额度而非用户额度": "Related API displays token quota instead of user quota", + "保存通用设置": "Save General Settings", + "监控设置": "Monitoring Settings", + "最长响应时间": "Longest Response Time", + "单位秒": "Unit in seconds", + "当运行通道全部测试时": "When all operating channels are tested", + "超过此时间将自动禁用通道": "Channels will be automatically disabled if this time is exceeded", + "额度提醒阈值": "Quota reminder threshold", + "低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this", + "失败时自动禁用通道": "Automatically disable the channel when it fails", + "保存监控设置": "Save Monitoring Settings", + "额度设置": "Quota Settings", + "新用户初始额度": "Initial quota for new users", + "例如": "For example", + "请求预扣费额度": "Request for pre-deducted quota", + "请求结束后多退少补": "Refund more or less after the request ends", + "邀请新用户奖励额度": "Invite new users to reward quota", + "新用户使用邀请码奖励额度": "New user rewards quota using invitation code", + "保存额度设置": "Save Quota Settings", + "倍率设置": "Rate Settings", + "模型倍率": "Model rate", + "为一个 JSON 文本": "Is a JSON text", + "键为模型名称": "Key is model name", + "值为倍率": "Value is the rate", + "分组倍率": "Group rate", + "键为分组名称": "Key is group name", + "保存倍率设置": "Save Rate Settings", + "已是最新版本": "Is the latest version", + "检查更新": "Check for updates", + "公告": "Announcement", + "在此输入新的公告内容,支持 Markdown & HTML 代码": "Enter the new announcement content here, supports Markdown & HTML code", + "保存公告": "Save Announcement", + "个性化设置": "Personalization Settings", + "系统名称": "System Name", + "在此输入系统名称": "Enter the system name here", + "设置系统名称": "Set system name", + "图片地址": "Image URL", + "在此输入 Logo 图片地址": "Enter the Logo image URL here", + "首页内容": "Home Page Content", + "在此输入首页内容,支持 Markdown & HTML 代码,设置后首页的状态信息将不再显示。如果输入的是一个链接,则会使用该链接作为 iframe 的 src 属性,这允许你设置任意网页作为首页": "Enter the homepage content here, supports Markdown & HTML code. Once set, the status information of the homepage will not be displayed. If a link is entered, it will be used as the src attribute of the iframe, allowing you to set any webpage as the homepage.", + "保存首页内容": "Save Home Page Content", + "在此输入新的关于内容,支持 Markdown & HTML 代码。如果输入的是一个链接,则会使用该链接作为 iframe 的 src 属性,这允许你设置任意网页作为关于页面": "Enter new about content here, supports Markdown & HTML code. If a link is entered, it will be used as the src attribute of the iframe, allowing you to set any webpage as the about page.", + "保存关于": "Save About", + "移除 One API 的版权标识必须首先获得授权,项目维护需要花费大量精力,如果本项目对你有意义,请主动支持本项目": "Removal of One API copyright mark must first be authorized. Project maintenance requires a lot of effort. If this project is meaningful to you, please actively support it.", + "页脚": "Footer", + "在此输入新的页脚,留空则使用默认页脚,支持 HTML 代码": "Enter the new footer here, leave blank to use the default footer, supports HTML code.", + "设置页脚": "Set Footer", + "新版本": "New Version", + "关闭": "Close", + "密码已重置并已复制到剪贴板": "Password has been reset and copied to clipboard", + "密码重置确认": "Password Reset Confirmation", + "邮箱地址": "Email Address", + "提交": "Submit", + "请稍后几秒重试": "Please retry in a few seconds", + "正在检查用户环境": "Checking user environment", + "重置邮件发送成功": "Reset mail sent successfully", + "请检查邮箱": "Please check your email", + "密码重置": "Password Reset", + "令牌已重置并已复制到剪贴板": "Token has been reset and copied to clipboard", + "邀请链接已复制到剪切板": "Invitation link has been copied to clipboard", + "微信账户绑定成功": "WeChat account binding succeeded", + "验证码发送成功": "Verification code sent successfully", + "邮箱账户绑定成功": "Email account binding succeeded", + "注意": "Note", + "此处生成的令牌用于系统管理": "The token generated here is used for system management", + "而非用于请求 OpenAI 相关的服务": "Not for requesting OpenAI related services", + "请知悉": "Please be aware", + "更新个人信息": "Update Personal Information", + "生成系统访问令牌": "Generate System Access Token", + "复制邀请链接": "Copy Invitation Link", + "账号绑定": "Account Binding", + "绑定微信账号": "Bind WeChat Account", + "微信扫码关注公众号": "Scan the QR code with WeChat to follow the official account", + "输入": "Enter", + "验证码": "Verification Code", + "获取验证码": "Get Verification Code", + "三分钟内有效": "Valid for three minutes", + "绑定": "Bind", + "绑定 GitHub 账号": "Bind GitHub Account", + "绑定邮箱地址": "Bind Email Address", + "输入邮箱地址": "Enter Email Address", + "未使用": "Unused", + "已使用": "Used", + "操作成功完成": "Operation successfully completed", + "搜索兑换码的 ID 和名称": "Search for ID and name", + "额度": "Quota", + "创建时间": "Creation Time", + "兑换时间": "Redemption Time", + "尚未兑换": "Not yet redeemed", + "已复制到剪贴板": "Copied to clipboard", + "无法复制到剪贴板": "Unable to copy to clipboard", + "请手动复制": "Please copy manually", + "已将兑换码填入搜索框": "The voucher code has been filled into the search box", + "复制": "Copy", + "添加新的兑换码": "Add a new voucher", + "密码长度不得小于 8 位": "Password length must not be less than 8 characters", + "两次输入的密码不一致": "The two passwords entered do not match", + "注册成功": "Registration succeeded", + "请稍后几秒重试,Turnstile 正在检查用户环境": "Please retry in a few seconds, Turnstile is checking user environment", + "验证码发送成功,请检查你的邮箱": "Verification code sent successfully, please check your email", + "新用户注册": "New User Registration", + "输入用户名,最长 12 位": "Enter username, up to 12 characters", + "输入密码,最短 8 位,最长 20 位": "Enter password, at least 8 characters and up to 20 characters", + "输入验证码": "Enter Verification Code", + "已有账户": "Already have an account", + "点击登录": "Click to log in", + "服务器地址": "Server Address", + "更新服务器地址": "Update Server Address", + "配置登录注册": "Configure Login/Registration", + "允许通过密码进行登录": "Allow login via password", + "允许通过密码进行注册": "Allow registration via password", + "通过密码注册时需要进行邮箱验证": "Email verification is required when registering via password", + "允许通过 GitHub 账户登录 & 注册": "Allow login & registration via GitHub account", + "允许通过微信登录 & 注册": "Allow login & registration via WeChat", + "允许新用户注册(此项为否时,新用户将无法以任何方式进行注册": "Allow new user registration (if this option is off, new users will not be able to register in any way", + "启用 Turnstile 用户校验": "Enable Turnstile user verification", + "配置 SMTP": "Configure SMTP", + "用以支持系统的邮件发送": "To support the system email sending", + "SMTP 服务器地址": "SMTP Server Address", + "例如:smtp.qq.com": "For example: smtp.qq.com", + "SMTP 端口": "SMTP Port", + "默认: 587": "Default: 587", + "SMTP 账户": "SMTP Account", + "通常是邮箱地址": "Usually an email address", + "发送者邮箱": "Sender email", + "通常和邮箱地址保持一致": "Usually consistent with the email address", + "SMTP 访问凭证": "SMTP Access Credential", + "敏感信息不会发送到前端显示": "Sensitive information will not be displayed in the frontend", + "保存 SMTP 设置": "Save SMTP Settings", + "配置 GitHub OAuth App": "Configure GitHub OAuth App", + "用以支持通过 GitHub 进行登录注册": "To support login & registration via GitHub", + "点击此处": "Click here", + "管理你的 GitHub OAuth App": "Manage your GitHub OAuth App", + "输入你注册的 GitHub OAuth APP 的 ID": "Enter your registered GitHub OAuth APP ID", + "保存 GitHub OAuth 设置": "Save GitHub OAuth Settings", + "配置 WeChat Server": "Configure WeChat Server", + "用以支持通过微信进行登录注册": "To support login & registration via WeChat", + "了解 WeChat Server": "Learn about WeChat Server", + "WeChat Server 访问凭证": "WeChat Server Access Credential", + "微信公众号二维码图片链接": "WeChat Public Account QR Code Image Link", + "输入一个图片链接": "Enter an image link", + "保存 WeChat Server 设置": "Save WeChat Server Settings", + "配置 Turnstile": "Configure Turnstile", + "用以支持用户校验": "To support user verification", + "管理你的 Turnstile Sites,推荐选择 Invisible Widget Type": "Manage your Turnstile Sites, recommend selecting Invisible Widget Type", + "输入你注册的 Turnstile Site Key": "Enter your registered Turnstile Site Key", + "保存 Turnstile 设置": "Save Turnstile Settings", + "已过期": "Expired", + "已耗尽": "Exhausted", + "搜索令牌的名称 ...": "Search for the name of the token...", + "已用额度": "Quota used", + "剩余额度": "Remaining quota", + "过期时间": "Expiration time", + "无": "None", + "无限制": "Unlimited", + "永不过期": "Never expires", + "无法复制到剪贴板,请手动复制,已将令牌填入搜索框": "Unable to copy to clipboard, please copy manually, the token has been entered into the search box", + "删除令牌": "Delete Token", + "添加新的令牌": "Add New Token", + "普通用户": "Regular User", + "管理员": "Admin", + "超级管理员": "Super Admin", + "未知身份": "Unknown Identity", + "已激活": "Activated", + "已封禁": "Banned", + "搜索用户的 ID,用户名,显示名称,以及邮箱地址 ...": "Search user ID, username, display name, and email address...", + "用户名": "Username", + "统计信息": "Statistics", + "用户角色": "User Role", + "未绑定邮箱地址": "Email not bound", + "请求次数": "Number of Requests", + "提升": "Promote", + "降级": "Demote", + "删除用户": "Delete User", + "添加新的用户": "Add New User", + "自定义": "Custom", + "等价金额": "Equivalent Amount", + "未登录或登录已过期,请重新登录": "Not logged in or login has expired, please log in again", + "请求次数过多,请稍后再试": "Too many requests, please try again later", + "服务器内部错误,请联系管理员": "Server internal error, please contact the administrator", + "本站仅作演示之用,无服务端": "This site is for demonstration purposes only, no server-side", + "超级管理员未设置充值链接!": "Super administrator has not set the recharge link!", + "错误:": "Error: ", + "新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面": "New version available: ${data.version}, please refresh the page using shortcut Shift + F5", + "无法正常连接至服务器": "Unable to connect to the server normally", + "管理渠道": "Manage Channels", + "系统状况": "System Status", + "系统信息": "System Information", + "系统信息总览": "System Information Overview", + "版本": "Version", + "源码": "Source Code", + "启动时间": "Startup Time", + "系统配置": "System Configuration", + "系统配置总览": "System Configuration Overview", + "邮箱验证": "Email Verification", + "未启用": "Not Enabled", + "GitHub 身份验证": "GitHub Authentication", + "微信身份验证": "WeChat Authentication", + "Turnstile 用户校验": "Turnstile User Verification", + "创建新的渠道": "Create New Channel", + "镜像": "Mirror", + "请输入镜像站地址,格式为:https://domain.com,可不填,不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used", + "模型": "Model", + "请选择该通道所支持的模型": "Please select the model supported by the channel", + "填入基础模型": "Fill in the basic model", + "填入所有模型": "Fill in all models", + "清除所有模型": "Clear all models", + "密钥": "Key", + "请输入密钥": "Please enter the key", + "批量创建": "Batch Create", + "更新渠道信息": "Update Channel Information", + "我的令牌": "My Tokens", + "管理兑换码": "Manage Redeem Codes", + "兑换码": "Redeem Code", + "管理用户": "Manage Users", + "额度明细": "Quota Details", + "个人设置": "Personal Settings", + "运营设置": "Operation Settings", + "系统设置": "System Settings", + "其他设置": "Other Settings", + "项目仓库地址": "Project Repository Address", + "可在设置页面设置关于内容,支持 HTML & Markdown": "You can set the content about in the settings page, support HTML & Markdown", + "由{' '}": "built by{' '}", + "构建,源代码遵循{' '}": ", the source code licensed under{' '}", + "MIT 协议": "MIT License", + "充值额度": "Recharge Quota", + "获取兑换码": "Get Redeem Code", + "一个月后过期": "Expires after one month", + "一天后过期": "Expires after one day", + "一小时后过期": "Expires after one hour", + "一分钟后过期": "Expires after one minute", + "创建新的令牌": "Create New Token", + "注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.", + "设为无限额度": "Set to unlimited quota", + "更新令牌信息": "Update Token Information", + "请输入充值码!": "Please enter the recharge code!", + "请输入名称": "Please enter a name", + "请输入密钥,一行一个": "Please enter the key, one per line", + "请输入额度": "Please enter the quota", + "令牌创建成功": "Token created successfully", + "令牌更新成功": "Token updated successfully", + "充值成功!": "Recharge successful!", + "更新用户信息": "Update User Information", + "请输入新的用户名": "Please enter a new username", + "密码": "Password", + "请输入新的密码": "Please enter a new password", + "显示名称": "Display Name", + "请输入新的显示名称": "Please enter a new display name", + "已绑定的 GitHub 账户": "GitHub Account Bound", + "此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改": "This item is read-only. Users need to bind through the relevant binding button on the personal settings page, and cannot be modified directly", + "已绑定的微信账户": "WeChat Account Bound", + "已绑定的邮箱账户": "Email Account Bound", + "用户信息更新成功!": "User information updated successfully!", + "模型倍率 %.2f,分组倍率 %.2f": "model rate %.2f, group rate %.2f", + "使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})", + "用户名称": "User Name", + "令牌名称": "Token Name", + "留空则查询全部用户": "Leave blank to query all users", + "留空则查询全部令牌": "Leave blank to query all tokens", + "模型名称": "Model Name", + "留空则查询全部模型": "Leave blank to query all models", + "起始时间": "Start Time", + "结束时间": "End Time", + "查询": "Query", + "提示": "Prompt", + "补全": "Completion", + "消耗额度": "Used Quota", + "可选值": "Optional Values", + "渠道不存在:%d": "Channel does not exist: %d", + "数据库一致性已被破坏,请联系管理员": "Database consistency has been broken, please contact the administrator", + "使用近似的方式估算 token 数以减少计算量": "Estimate the number of tokens in an approximate way to reduce computational load", + "请填写ChannelName和ChannelKey!": "Please fill in the ChannelName and ChannelKey!", + "请至少选择一个Model!": "Please select at least one Model!", + "加载首页内容失败": "Failed to load the homepage content", + "加载关于内容失败": "Failed to load the About content", + "兑换码更新成功!": "Redemption code updated successfully!", + "兑换码创建成功!": "Redemption code created successfully!", + "用户账户创建成功!": "User account created successfully!", + "生成数量": "Generate quantity", + "请输入生成数量": "Please enter the quantity to generate", + "创建新用户账户": "Create new user account", + "渠道更新成功!": "Channel updated successfully!", + "渠道创建成功!": "Channel created successfully!", + "请选择分组": "Please select a group", + "更新兑换码信息": "Update redemption code information", + "创建新的兑换码": "Create a new redemption code", + "请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit the group ratio in the system settings page to add a new group:", + "未找到所请求的页面": "The requested page was not found", + "过期时间格式错误!": "Expiration time format error!", + "请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss,-1 表示无限制": "Please enter the expiration time, the format is yyyy-MM-dd HH:mm:ss, -1 means no limit", + "此项可选,为一个 JSON 文本,键为用户请求的模型名称,值为要替换的模型名称,例如:": "This is optional, it's a JSON text, the key is the model name requested by the user, and the value is the model name to be replaced, for example:", + "此项可选,输入镜像站地址,格式为:": "This is optional, enter the mirror site address, the format is:", + "模型映射": "Model mapping", + "请输入默认 API 版本,例如:2023-03-15-preview,该配置可以被实际的请求查询参数所覆盖": "Please enter the default API version, for example: 2023-03-15-preview, this configuration can be overridden by the actual request query parameters", + "默认": "Default", + "图片演示": "Image demo", + "参数替换为你的部署名称(模型名称中的点会被剔除)": "Replace the parameter with your deployment name (dots in the model name will be removed)", + "模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!", + "取消无限额度": "Cancel unlimited quota", + "取消": "Cancel", + "请输入新的剩余额度": "Please enter the new remaining quota", + "请输入单个兑换码中包含的额度": "Please enter the quota included in a single redemption code", + "请输入用户名": "Please enter username", + "请输入显示名称": "Please enter display name", + "请输入密码": "Please enter password", + "模型部署名称必须和模型名称保持一致": "The model deployment name must be consistent with the model name", + ",因为 One API 会把请求体中的 model": ", because One API will take the model in the request body", + "请输入 AZURE_OPENAI_ENDPOINT": "Please enter AZURE_OPENAI_ENDPOINT", + "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", + "Homepage URL 填": "Fill in the Homepage URL", + "Authorization callback URL 填": "Fill in the Authorization callback URL", + "请为通道命名": "Please name the channel", + "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", + "模型重定向": "Model redirection", + "请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel", + "注意,": "Note that, ", + ",图片演示。": "related image demo.", + "令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!", + "代理": "Proxy", + "此项可选,用于通过代理站来进行 API 调用,请输入代理站地址,格式为:https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com", + "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?", + "按照如下格式输入:": "Enter in the following format:", + "模型版本": "Model version", + "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1", + "点击查看": "click to view" +} diff --git a/i18n/translate.py b/i18n/translate.py new file mode 100644 index 0000000000000000000000000000000000000000..6ba9bc5dc1628fcf7b30180d2e19560e373bd6a8 --- /dev/null +++ b/i18n/translate.py @@ -0,0 +1,61 @@ +import argparse +import json +import os + +def list_file_paths(path): + file_paths = [] + for root, dirs, files in os.walk(path): + if "node_modules" in dirs: + dirs.remove("node_modules") + if "build" in dirs: + dirs.remove("build") + if "i18n" in dirs: + dirs.remove("i18n") + for file in files: + file_path = os.path.join(root, file) + if file_path.endswith("png") or file_path.endswith("ico") or file_path.endswith("db") or file_path.endswith("exe"): + continue + file_paths.append(file_path) + + for dir in dirs: + dir_path = os.path.join(root, dir) + file_paths += list_file_paths(dir_path) + + return file_paths + + +def replace_keys_in_repository(repo_path, json_file_path): + with open(json_file_path, 'r', encoding="utf-8") as json_file: + key_value_pairs = json.load(json_file) + + pairs = [] + for key, value in key_value_pairs.items(): + pairs.append((key, value)) + pairs.sort(key=lambda x: len(x[0]), reverse=True) + + files = list_file_paths(repo_path) + print('Total files: {}'.format(len(files))) + for file_path in files: + replace_keys_in_file(file_path, pairs) + + +def replace_keys_in_file(file_path, pairs): + try: + with open(file_path, 'r', encoding="utf-8") as file: + content = file.read() + + for key, value in pairs: + content = content.replace(key, value) + + with open(file_path, 'w', encoding="utf-8") as file: + file.write(content) + except UnicodeDecodeError: + print('UnicodeDecodeError: {}'.format(file_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Replace keys in repository.') + parser.add_argument('--repository_path', help='Path to repository') + parser.add_argument('--json_file_path', help='Path to JSON file') + args = parser.parse_args() + replace_keys_in_repository(args.repository_path, args.json_file_path) diff --git a/main.go b/main.go new file mode 100644 index 0000000000000000000000000000000000000000..9fb0a73e0fd2a968e1f17bb22828fc9998569437 --- /dev/null +++ b/main.go @@ -0,0 +1,101 @@ +package main + +import ( + "embed" + "github.com/gin-contrib/sessions" + "github.com/gin-contrib/sessions/cookie" + "github.com/gin-gonic/gin" + "one-api/common" + "one-api/controller" + "one-api/middleware" + "one-api/model" + "one-api/router" + "os" + "strconv" +) + +//go:embed web/build +var buildFS embed.FS + +//go:embed web/build/index.html +var indexPage []byte + +func main() { + common.SetupGinLog() + common.SysLog("One API " + common.Version + " started") + if os.Getenv("GIN_MODE") != "debug" { + gin.SetMode(gin.ReleaseMode) + } + if common.DebugEnabled { + common.SysLog("running in debug mode") + } + // Initialize SQL Database + err := model.InitDB() + if err != nil { + common.FatalLog("failed to initialize database: " + err.Error()) + } + defer func() { + err := model.CloseDB() + if err != nil { + common.FatalLog("failed to close database: " + err.Error()) + } + }() + + // Initialize Redis + err = common.InitRedisClient() + if err != nil { + common.FatalLog("failed to initialize Redis: " + err.Error()) + } + + // Initialize options + model.InitOptionMap() + if common.RedisEnabled { + model.InitChannelCache() + } + if os.Getenv("SYNC_FREQUENCY") != "" { + frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY")) + if err != nil { + common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error()) + } + common.SyncFrequency = frequency + go model.SyncOptions(frequency) + if common.RedisEnabled { + go model.SyncChannelCache(frequency) + } + } + if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { + frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) + if err != nil { + common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) + } + go controller.AutomaticallyUpdateChannels(frequency) + } + if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { + frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) + if err != nil { + common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) + } + go controller.AutomaticallyTestChannels(frequency) + } + controller.InitTokenEncoders() + + // Initialize HTTP server + server := gin.Default() + // This will cause SSE not to work!!! + //server.Use(gzip.Gzip(gzip.DefaultCompression)) + server.Use(middleware.CORS()) + + // Initialize session store + store := cookie.NewStore([]byte(common.SessionSecret)) + server.Use(sessions.Sessions("session", store)) + + router.SetRouter(server, buildFS, indexPage) + var port = os.Getenv("PORT") + if port == "" { + port = strconv.Itoa(*common.Port) + } + err = server.Run(":" + port) + if err != nil { + common.FatalLog("failed to start HTTP server: " + err.Error()) + } +} diff --git a/middleware/auth.go b/middleware/auth.go new file mode 100644 index 0000000000000000000000000000000000000000..060e005c900bfd044fa009c4408bfd62d8dde26e --- /dev/null +++ b/middleware/auth.go @@ -0,0 +1,138 @@ +package middleware + +import ( + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/model" + "strings" +) + +func authHelper(c *gin.Context, minRole int) { + session := sessions.Default(c) + username := session.Get("username") + role := session.Get("role") + id := session.Get("id") + status := session.Get("status") + if username == nil { + // Check access token + accessToken := c.Request.Header.Get("Authorization") + if accessToken == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "success": false, + "message": "无权进行此操作,未登录且未提供 access token", + }) + c.Abort() + return + } + user := model.ValidateAccessToken(accessToken) + if user != nil && user.Username != "" { + // Token is valid + username = user.Username + role = user.Role + id = user.Id + status = user.Status + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权进行此操作,access token 无效", + }) + c.Abort() + return + } + } + if status.(int) == common.UserStatusDisabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "用户已被封禁", + }) + c.Abort() + return + } + if role.(int) < minRole { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无权进行此操作,权限不足", + }) + c.Abort() + return + } + c.Set("username", username) + c.Set("role", role) + c.Set("id", id) + c.Next() +} + +func UserAuth() func(c *gin.Context) { + return func(c *gin.Context) { + authHelper(c, common.RoleCommonUser) + } +} + +func AdminAuth() func(c *gin.Context) { + return func(c *gin.Context) { + authHelper(c, common.RoleAdminUser) + } +} + +func RootAuth() func(c *gin.Context) { + return func(c *gin.Context) { + authHelper(c, common.RoleRootUser) + } +} + +func TokenAuth() func(c *gin.Context) { + return func(c *gin.Context) { + key := c.Request.Header.Get("Authorization") + key = strings.TrimPrefix(key, "Bearer ") + key = strings.TrimPrefix(key, "sk-") + parts := strings.Split(key, "-") + key = parts[0] + token, err := model.ValidateUserToken(key) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + c.Abort() + return + } + if !model.CacheIsUserEnabled(token.UserId) { + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "message": "用户已被封禁", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + c.Set("id", token.UserId) + c.Set("token_id", token.Id) + c.Set("token_name", token.Name) + requestURL := c.Request.URL.String() + consumeQuota := true + if strings.HasPrefix(requestURL, "/v1/models") { + consumeQuota = false + } + c.Set("consume_quota", consumeQuota) + if len(parts) > 1 { + if model.IsAdmin(token.UserId) { + c.Set("channelId", parts[1]) + } else { + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "message": "普通用户不支持指定渠道", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + } + c.Next() + } +} diff --git a/middleware/cache.go b/middleware/cache.go new file mode 100644 index 0000000000000000000000000000000000000000..979734ab2fb2817c1274fc934181001a2d5dd375 --- /dev/null +++ b/middleware/cache.go @@ -0,0 +1,16 @@ +package middleware + +import ( + "github.com/gin-gonic/gin" +) + +func Cache() func(c *gin.Context) { + return func(c *gin.Context) { + if c.Request.RequestURI == "/" { + c.Header("Cache-Control", "no-cache") + } else { + c.Header("Cache-Control", "max-age=604800") // one week + } + c.Next() + } +} diff --git a/middleware/cors.go b/middleware/cors.go new file mode 100644 index 0000000000000000000000000000000000000000..d2a109abece64ceb4268dcefa1a5dcb7dcf85e67 --- /dev/null +++ b/middleware/cors.go @@ -0,0 +1,15 @@ +package middleware + +import ( + "github.com/gin-contrib/cors" + "github.com/gin-gonic/gin" +) + +func CORS() gin.HandlerFunc { + config := cors.DefaultConfig() + config.AllowAllOrigins = true + config.AllowCredentials = true + config.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"} + config.AllowHeaders = []string{"*"} + return cors.New(config) +} diff --git a/middleware/distributor.go b/middleware/distributor.go new file mode 100644 index 0000000000000000000000000000000000000000..93827c9518e81ac007d8f04a1fe88a7735ebfb1f --- /dev/null +++ b/middleware/distributor.go @@ -0,0 +1,123 @@ +package middleware + +import ( + "fmt" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + "strings" + + "github.com/gin-gonic/gin" +) + +type ModelRequest struct { + Model string `json:"model"` +} + +func Distribute() func(c *gin.Context) { + return func(c *gin.Context) { + userId := c.GetInt("id") + userGroup, _ := model.CacheGetUserGroup(userId) + c.Set("group", userGroup) + var channel *model.Channel + channelId, ok := c.Get("channelId") + if ok { + id, err := strconv.Atoi(channelId.(string)) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "无效的渠道 ID", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + channel, err = model.GetChannelById(id, true) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "无效的渠道 ID", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + if channel.Status != common.ChannelStatusEnabled { + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "message": "该渠道已被禁用", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + } else { + // Select a channel for the user + var modelRequest ModelRequest + var err error + if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + err = common.UnmarshalBodyReusable(c, &modelRequest) + } + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "无效的请求", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + if modelRequest.Model == "" { + modelRequest.Model = "text-moderation-stable" + } + } + if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + if modelRequest.Model == "" { + modelRequest.Model = c.Param("model") + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + if modelRequest.Model == "" { + modelRequest.Model = "dall-e" + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + if modelRequest.Model == "" { + modelRequest.Model = "whisper-1" + } + } + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + if err != nil { + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + if channel != nil { + common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) + message = "数据库一致性已被破坏,请联系管理员" + } + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "message": message, + "type": "one_api_error", + }, + }) + c.Abort() + return + } + } + c.Set("channel", channel.Type) + c.Set("channel_id", channel.Id) + c.Set("channel_name", channel.Name) + c.Set("model_mapping", channel.ModelMapping) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + c.Set("base_url", channel.BaseURL) + if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei { + c.Set("api_version", channel.Other) + } + c.Next() + } +} diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go new file mode 100644 index 0000000000000000000000000000000000000000..8e5cff6cf40a42bb304dd45b8f5dcdc665dcfd33 --- /dev/null +++ b/middleware/rate-limit.go @@ -0,0 +1,103 @@ +package middleware + +import ( + "context" + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "time" +) + +var timeFormat = "2006-01-02T15:04:05.000Z" + +var inMemoryRateLimiter common.InMemoryRateLimiter + +func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { + ctx := context.Background() + rdb := common.RDB + key := "rateLimit:" + mark + c.ClientIP() + listLength, err := rdb.LLen(ctx, key).Result() + if err != nil { + fmt.Println(err.Error()) + c.Status(http.StatusInternalServerError) + c.Abort() + return + } + if listLength < int64(maxRequestNum) { + rdb.LPush(ctx, key, time.Now().Format(timeFormat)) + rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + } else { + oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() + oldTime, err := time.Parse(timeFormat, oldTimeStr) + if err != nil { + fmt.Println(err) + c.Status(http.StatusInternalServerError) + c.Abort() + return + } + nowTimeStr := time.Now().Format(timeFormat) + nowTime, err := time.Parse(timeFormat, nowTimeStr) + if err != nil { + fmt.Println(err) + c.Status(http.StatusInternalServerError) + c.Abort() + return + } + // time.Since will return negative number! + // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows + if int64(nowTime.Sub(oldTime).Seconds()) < duration { + rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + c.Status(http.StatusTooManyRequests) + c.Abort() + return + } else { + rdb.LPush(ctx, key, time.Now().Format(timeFormat)) + rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) + rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + } + } +} + +func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { + key := mark + c.ClientIP() + if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) { + c.Status(http.StatusTooManyRequests) + c.Abort() + return + } +} + +func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) { + if common.RedisEnabled { + return func(c *gin.Context) { + redisRateLimiter(c, maxRequestNum, duration, mark) + } + } else { + // It's safe to call multi times. + inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) + return func(c *gin.Context) { + memoryRateLimiter(c, maxRequestNum, duration, mark) + } + } +} + +func GlobalWebRateLimit() func(c *gin.Context) { + return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") +} + +func GlobalAPIRateLimit() func(c *gin.Context) { + return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") +} + +func CriticalRateLimit() func(c *gin.Context) { + return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") +} + +func DownloadRateLimit() func(c *gin.Context) { + return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") +} + +func UploadRateLimit() func(c *gin.Context) { + return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") +} diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go new file mode 100644 index 0000000000000000000000000000000000000000..26688810d0350845280b29da280e65d7ad3ac3ed --- /dev/null +++ b/middleware/turnstile-check.go @@ -0,0 +1,80 @@ +package middleware + +import ( + "encoding/json" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "net/http" + "net/url" + "one-api/common" +) + +type turnstileCheckResponse struct { + Success bool `json:"success"` +} + +func TurnstileCheck() gin.HandlerFunc { + return func(c *gin.Context) { + if common.TurnstileCheckEnabled { + session := sessions.Default(c) + turnstileChecked := session.Get("turnstile") + if turnstileChecked != nil { + c.Next() + return + } + response := c.Query("turnstile") + if response == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Turnstile token 为空", + }) + c.Abort() + return + } + rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ + "secret": {common.TurnstileSecretKey}, + "response": {response}, + "remoteip": {c.ClientIP()}, + }) + if err != nil { + common.SysError(err.Error()) + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + c.Abort() + return + } + defer rawRes.Body.Close() + var res turnstileCheckResponse + err = json.NewDecoder(rawRes.Body).Decode(&res) + if err != nil { + common.SysError(err.Error()) + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + c.Abort() + return + } + if !res.Success { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Turnstile 校验失败,请刷新重试!", + }) + c.Abort() + return + } + session.Set("turnstile", true) + err = session.Save() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "message": "无法保存会话信息,请重试", + "success": false, + }) + return + } + } + c.Next() + } +} diff --git a/model/ability.go b/model/ability.go new file mode 100644 index 0000000000000000000000000000000000000000..e87c3940909c0ffbdee4c53ad06d6f636bb2701d --- /dev/null +++ b/model/ability.go @@ -0,0 +1,73 @@ +package model + +import ( + "one-api/common" + "strings" +) + +type Ability struct { + Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"` + Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` + ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` + Enabled bool `json:"enabled"` +} + +func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { + ability := Ability{} + var err error = nil + if common.UsingSQLite { + err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error + } else { + err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error + } + if err != nil { + return nil, err + } + channel := Channel{} + channel.Id = ability.ChannelId + err = DB.First(&channel, "id = ?", ability.ChannelId).Error + return &channel, err +} + +func (channel *Channel) AddAbilities() error { + models_ := strings.Split(channel.Models, ",") + groups_ := strings.Split(channel.Group, ",") + abilities := make([]Ability, 0, len(models_)) + for _, model := range models_ { + for _, group := range groups_ { + ability := Ability{ + Group: group, + Model: model, + ChannelId: channel.Id, + Enabled: channel.Status == common.ChannelStatusEnabled, + } + abilities = append(abilities, ability) + } + } + return DB.Create(&abilities).Error +} + +func (channel *Channel) DeleteAbilities() error { + return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error +} + +// UpdateAbilities updates abilities of this channel. +// Make sure the channel is completed before calling this function. +func (channel *Channel) UpdateAbilities() error { + // A quick and dirty way to update abilities + // First delete all abilities of this channel + err := channel.DeleteAbilities() + if err != nil { + return err + } + // Then add new abilities + err = channel.AddAbilities() + if err != nil { + return err + } + return nil +} + +func UpdateAbilityStatus(channelId int, status bool) error { + return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error +} diff --git a/model/cache.go b/model/cache.go new file mode 100644 index 0000000000000000000000000000000000000000..55fbba9b05b7e6733e0f533cfaea479216a3a8ff --- /dev/null +++ b/model/cache.go @@ -0,0 +1,183 @@ +package model + +import ( + "encoding/json" + "errors" + "fmt" + "math/rand" + "one-api/common" + "strconv" + "strings" + "sync" + "time" +) + +var ( + TokenCacheSeconds = common.SyncFrequency + UserId2GroupCacheSeconds = common.SyncFrequency + UserId2QuotaCacheSeconds = common.SyncFrequency + UserId2StatusCacheSeconds = common.SyncFrequency +) + +func CacheGetTokenByKey(key string) (*Token, error) { + var token Token + if !common.RedisEnabled { + err := DB.Where("`key` = ?", key).First(&token).Error + return &token, err + } + tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) + if err != nil { + err := DB.Where("`key` = ?", key).First(&token).Error + if err != nil { + return nil, err + } + jsonBytes, err := json.Marshal(token) + if err != nil { + return nil, err + } + err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) + if err != nil { + common.SysError("Redis set token error: " + err.Error()) + } + return &token, nil + } + err = json.Unmarshal([]byte(tokenObjectString), &token) + return &token, err +} + +func CacheGetUserGroup(id int) (group string, err error) { + if !common.RedisEnabled { + return GetUserGroup(id) + } + group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id)) + if err != nil { + group, err = GetUserGroup(id) + if err != nil { + return "", err + } + err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) + if err != nil { + common.SysError("Redis set user group error: " + err.Error()) + } + } + return group, err +} + +func CacheGetUserQuota(id int) (quota int, err error) { + if !common.RedisEnabled { + return GetUserQuota(id) + } + quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) + if err != nil { + quota, err = GetUserQuota(id) + if err != nil { + return 0, err + } + err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) + if err != nil { + common.SysError("Redis set user quota error: " + err.Error()) + } + return quota, err + } + quota, err = strconv.Atoi(quotaString) + return quota, err +} + +func CacheUpdateUserQuota(id int) error { + if !common.RedisEnabled { + return nil + } + quota, err := GetUserQuota(id) + if err != nil { + return err + } + err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) + return err +} + +func CacheDecreaseUserQuota(id int, quota int) error { + if !common.RedisEnabled { + return nil + } + err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota)) + return err +} + +func CacheIsUserEnabled(userId int) bool { + if !common.RedisEnabled { + return IsUserEnabled(userId) + } + enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) + if err != nil { + status := common.UserStatusDisabled + if IsUserEnabled(userId) { + status = common.UserStatusEnabled + } + enabled = fmt.Sprintf("%d", status) + err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) + if err != nil { + common.SysError("Redis set user enabled error: " + err.Error()) + } + } + return enabled == "1" +} + +var group2model2channels map[string]map[string][]*Channel +var channelSyncLock sync.RWMutex + +func InitChannelCache() { + newChannelId2channel := make(map[int]*Channel) + var channels []*Channel + DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) + for _, channel := range channels { + newChannelId2channel[channel.Id] = channel + } + var abilities []*Ability + DB.Find(&abilities) + groups := make(map[string]bool) + for _, ability := range abilities { + groups[ability.Group] = true + } + newGroup2model2channels := make(map[string]map[string][]*Channel) + for group := range groups { + newGroup2model2channels[group] = make(map[string][]*Channel) + } + for _, channel := range channels { + groups := strings.Split(channel.Group, ",") + for _, group := range groups { + models := strings.Split(channel.Models, ",") + for _, model := range models { + if _, ok := newGroup2model2channels[group][model]; !ok { + newGroup2model2channels[group][model] = make([]*Channel, 0) + } + newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel) + } + } + } + channelSyncLock.Lock() + group2model2channels = newGroup2model2channels + channelSyncLock.Unlock() + common.SysLog("channels synced from database") +} + +func SyncChannelCache(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Second) + common.SysLog("syncing channels from database") + InitChannelCache() + } +} + +func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { + if !common.RedisEnabled { + return GetRandomSatisfiedChannel(group, model) + } + channelSyncLock.RLock() + defer channelSyncLock.RUnlock() + channels := group2model2channels[group][model] + if len(channels) == 0 { + return nil, errors.New("channel not found") + } + idx := rand.Intn(len(channels)) + return channels[idx], nil +} diff --git a/model/channel.go b/model/channel.go new file mode 100644 index 0000000000000000000000000000000000000000..7cc9fa9b047f3554d52f93eb2d47755e16520275 --- /dev/null +++ b/model/channel.go @@ -0,0 +1,148 @@ +package model + +import ( + "gorm.io/gorm" + "one-api/common" +) + +type Channel struct { + Id int `json:"id"` + Type int `json:"type" gorm:"default:0"` + Key string `json:"key" gorm:"not null;index"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index"` + Weight int `json:"weight"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + TestTime int64 `json:"test_time" gorm:"bigint"` + ResponseTime int `json:"response_time"` // in milliseconds + BaseURL string `json:"base_url" gorm:"column:base_url"` + Other string `json:"other"` + Balance float64 `json:"balance"` // in USD + BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` + Models string `json:"models"` + Group string `json:"group" gorm:"type:varchar(32);default:'default'"` + UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` + ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` +} + +func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { + var channels []*Channel + var err error + if selectAll { + err = DB.Order("id desc").Find(&channels).Error + } else { + err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error + } + return channels, err +} + +func SearchChannels(keyword string) (channels []*Channel, err error) { + err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error + return channels, err +} + +func GetChannelById(id int, selectAll bool) (*Channel, error) { + channel := Channel{Id: id} + var err error = nil + if selectAll { + err = DB.First(&channel, "id = ?", id).Error + } else { + err = DB.Omit("key").First(&channel, "id = ?", id).Error + } + return &channel, err +} + +func GetRandomChannel() (*Channel, error) { + channel := Channel{} + var err error = nil + if common.UsingSQLite { + err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error + } else { + err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error + } + return &channel, err +} + +func BatchInsertChannels(channels []Channel) error { + var err error + err = DB.Create(&channels).Error + if err != nil { + return err + } + for _, channel_ := range channels { + err = channel_.AddAbilities() + if err != nil { + return err + } + } + return nil +} + +func (channel *Channel) Insert() error { + var err error + err = DB.Create(channel).Error + if err != nil { + return err + } + err = channel.AddAbilities() + return err +} + +func (channel *Channel) Update() error { + var err error + err = DB.Model(channel).Updates(channel).Error + if err != nil { + return err + } + DB.Model(channel).First(channel, "id = ?", channel.Id) + err = channel.UpdateAbilities() + return err +} + +func (channel *Channel) UpdateResponseTime(responseTime int64) { + err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ + TestTime: common.GetTimestamp(), + ResponseTime: int(responseTime), + }).Error + if err != nil { + common.SysError("failed to update response time: " + err.Error()) + } +} + +func (channel *Channel) UpdateBalance(balance float64) { + err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ + BalanceUpdatedTime: common.GetTimestamp(), + Balance: balance, + }).Error + if err != nil { + common.SysError("failed to update balance: " + err.Error()) + } +} + +func (channel *Channel) Delete() error { + var err error + err = DB.Delete(channel).Error + if err != nil { + return err + } + err = channel.DeleteAbilities() + return err +} + +func UpdateChannelStatusById(id int, status int) { + err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) + if err != nil { + common.SysError("failed to update ability status: " + err.Error()) + } + err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error + if err != nil { + common.SysError("failed to update channel status: " + err.Error()) + } +} + +func UpdateChannelUsedQuota(id int, quota int) { + err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error + if err != nil { + common.SysError("failed to update channel used quota: " + err.Error()) + } +} diff --git a/model/log.go b/model/log.go new file mode 100644 index 0000000000000000000000000000000000000000..b0d6409aa68966dc4757eac464983617f9c352d9 --- /dev/null +++ b/model/log.go @@ -0,0 +1,168 @@ +package model + +import ( + "gorm.io/gorm" + "one-api/common" +) + +type Log struct { + Id int `json:"id"` + UserId int `json:"user_id"` + CreatedAt int64 `json:"created_at" gorm:"bigint;index"` + Type int `json:"type" gorm:"index"` + Content string `json:"content"` + Username string `json:"username" gorm:"index;default:''"` + TokenName string `json:"token_name" gorm:"index;default:''"` + ModelName string `json:"model_name" gorm:"index;default:''"` + Quota int `json:"quota" gorm:"default:0"` + PromptTokens int `json:"prompt_tokens" gorm:"default:0"` + CompletionTokens int `json:"completion_tokens" gorm:"default:0"` +} + +const ( + LogTypeUnknown = iota + LogTypeTopup + LogTypeConsume + LogTypeManage + LogTypeSystem +) + +func RecordLog(userId int, logType int, content string) { + if logType == LogTypeConsume && !common.LogConsumeEnabled { + return + } + log := &Log{ + UserId: userId, + Username: GetUsernameById(userId), + CreatedAt: common.GetTimestamp(), + Type: logType, + Content: content, + } + err := DB.Create(log).Error + if err != nil { + common.SysError("failed to record log: " + err.Error()) + } +} + +func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { + if !common.LogConsumeEnabled { + return + } + log := &Log{ + UserId: userId, + Username: GetUsernameById(userId), + CreatedAt: common.GetTimestamp(), + Type: LogTypeConsume, + Content: content, + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TokenName: tokenName, + ModelName: modelName, + Quota: quota, + } + err := DB.Create(log).Error + if err != nil { + common.SysError("failed to record log: " + err.Error()) + } +} + +func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) { + var tx *gorm.DB + if logType == LogTypeUnknown { + tx = DB + } else { + tx = DB.Where("type = ?", logType) + } + if modelName != "" { + tx = tx.Where("model_name = ?", modelName) + } + if username != "" { + tx = tx.Where("username = ?", username) + } + if tokenName != "" { + tx = tx.Where("token_name = ?", tokenName) + } + if startTimestamp != 0 { + tx = tx.Where("created_at >= ?", startTimestamp) + } + if endTimestamp != 0 { + tx = tx.Where("created_at <= ?", endTimestamp) + } + err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error + return logs, err +} + +func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { + var tx *gorm.DB + if logType == LogTypeUnknown { + tx = DB.Where("user_id = ?", userId) + } else { + tx = DB.Where("user_id = ? and type = ?", userId, logType) + } + if modelName != "" { + tx = tx.Where("model_name = ?", modelName) + } + if tokenName != "" { + tx = tx.Where("token_name = ?", tokenName) + } + if startTimestamp != 0 { + tx = tx.Where("created_at >= ?", startTimestamp) + } + if endTimestamp != 0 { + tx = tx.Where("created_at <= ?", endTimestamp) + } + err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error + return logs, err +} + +func SearchAllLogs(keyword string) (logs []*Log, err error) { + err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error + return logs, err +} + +func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { + err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error + return logs, err +} + +func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) { + tx := DB.Table("logs").Select("sum(quota)") + if username != "" { + tx = tx.Where("username = ?", username) + } + if tokenName != "" { + tx = tx.Where("token_name = ?", tokenName) + } + if startTimestamp != 0 { + tx = tx.Where("created_at >= ?", startTimestamp) + } + if endTimestamp != 0 { + tx = tx.Where("created_at <= ?", endTimestamp) + } + if modelName != "" { + tx = tx.Where("model_name = ?", modelName) + } + tx.Where("type = ?", LogTypeConsume).Scan("a) + return quota +} + +func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { + tx := DB.Table("logs").Select("sum(prompt_tokens) + sum(completion_tokens)") + if username != "" { + tx = tx.Where("username = ?", username) + } + if tokenName != "" { + tx = tx.Where("token_name = ?", tokenName) + } + if startTimestamp != 0 { + tx = tx.Where("created_at >= ?", startTimestamp) + } + if endTimestamp != 0 { + tx = tx.Where("created_at <= ?", endTimestamp) + } + if modelName != "" { + tx = tx.Where("model_name = ?", modelName) + } + tx.Where("type = ?", LogTypeConsume).Scan(&token) + return token +} diff --git a/model/main.go b/model/main.go new file mode 100644 index 0000000000000000000000000000000000000000..d422c4e0acec5e907a078916b868a940a446dba3 --- /dev/null +++ b/model/main.go @@ -0,0 +1,128 @@ +package model + +import ( + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "one-api/common" + "os" + "strings" + "time" +) + +var DB *gorm.DB + +func createRootAccountIfNeed() error { + var user User + //if user.Status != common.UserStatusEnabled { + if err := DB.First(&user).Error; err != nil { + common.SysLog("no user exists, create a root user for you: username is root, password is 123456") + hashedPassword, err := common.Password2Hash("123456") + if err != nil { + return err + } + rootUser := User{ + Username: "root", + Password: hashedPassword, + Role: common.RoleRootUser, + Status: common.UserStatusEnabled, + DisplayName: "Root User", + AccessToken: common.GetUUID(), + Quota: 100000000, + } + DB.Create(&rootUser) + } + return nil +} + +func chooseDB() (*gorm.DB, error) { + if os.Getenv("SQL_DSN") != "" { + dsn := os.Getenv("SQL_DSN") + if strings.HasPrefix(dsn, "postgres://") { + // Use PostgreSQL + common.SysLog("using PostgreSQL as database") + return gorm.Open(postgres.New(postgres.Config{ + DSN: dsn, + PreferSimpleProtocol: true, // disables implicit prepared statement usage + }), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) + } + // Use MySQL + common.SysLog("using MySQL as database") + return gorm.Open(mysql.Open(dsn), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) + } + // Use SQLite + common.SysLog("SQL_DSN not set, using SQLite as database") + common.UsingSQLite = true + return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) +} + +func InitDB() (err error) { + db, err := chooseDB() + if err == nil { + if common.DebugEnabled { + db = db.Debug() + } + DB = db + sqlDB, err := DB.DB() + if err != nil { + return err + } + sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) + + if !common.IsMasterNode { + return nil + } + err = db.AutoMigrate(&Channel{}) + if err != nil { + return err + } + err = db.AutoMigrate(&Token{}) + if err != nil { + return err + } + err = db.AutoMigrate(&User{}) + if err != nil { + return err + } + err = db.AutoMigrate(&Option{}) + if err != nil { + return err + } + err = db.AutoMigrate(&Redemption{}) + if err != nil { + return err + } + err = db.AutoMigrate(&Ability{}) + if err != nil { + return err + } + err = db.AutoMigrate(&Log{}) + if err != nil { + return err + } + common.SysLog("database migrated") + err = createRootAccountIfNeed() + return err + } else { + common.FatalLog(err) + } + return err +} + +func CloseDB() error { + sqlDB, err := DB.DB() + if err != nil { + return err + } + err = sqlDB.Close() + return err +} diff --git a/model/option.go b/model/option.go new file mode 100644 index 0000000000000000000000000000000000000000..4ef4d260fba937c6a2bbd891da66fb1da31f143c --- /dev/null +++ b/model/option.go @@ -0,0 +1,222 @@ +package model + +import ( + "one-api/common" + "strconv" + "strings" + "time" +) + +type Option struct { + Key string `json:"key" gorm:"primaryKey"` + Value string `json:"value"` +} + +func AllOption() ([]*Option, error) { + var options []*Option + var err error + err = DB.Find(&options).Error + return options, err +} + +func InitOptionMap() { + common.OptionMapRWMutex.Lock() + common.OptionMap = make(map[string]string) + common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) + common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) + common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) + common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) + common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) + common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) + common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) + common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) + common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) + common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) + common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) + common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) + common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) + common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) + common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) + common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) + common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) + common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) + common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") + common.OptionMap["SMTPServer"] = "" + common.OptionMap["SMTPFrom"] = "" + common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) + common.OptionMap["SMTPAccount"] = "" + common.OptionMap["SMTPToken"] = "" + common.OptionMap["Notice"] = "" + common.OptionMap["About"] = "" + common.OptionMap["HomePageContent"] = "" + common.OptionMap["Footer"] = common.Footer + common.OptionMap["SystemName"] = common.SystemName + common.OptionMap["Logo"] = common.Logo + common.OptionMap["ServerAddress"] = "" + common.OptionMap["GitHubClientId"] = "" + common.OptionMap["GitHubClientSecret"] = "" + common.OptionMap["WeChatServerAddress"] = "" + common.OptionMap["WeChatServerToken"] = "" + common.OptionMap["WeChatAccountQRCodeImageURL"] = "" + common.OptionMap["TurnstileSiteKey"] = "" + common.OptionMap["TurnstileSecretKey"] = "" + common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) + common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) + common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) + common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) + common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) + common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() + common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() + common.OptionMap["TopUpLink"] = common.TopUpLink + common.OptionMap["ChatLink"] = common.ChatLink + common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) + common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) + common.OptionMapRWMutex.Unlock() + loadOptionsFromDatabase() +} + +func loadOptionsFromDatabase() { + options, _ := AllOption() + for _, option := range options { + err := updateOptionMap(option.Key, option.Value) + if err != nil { + common.SysError("failed to update option map: " + err.Error()) + } + } +} + +func SyncOptions(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Second) + common.SysLog("syncing options from database") + loadOptionsFromDatabase() + } +} + +func UpdateOption(key string, value string) error { + // Save to database first + option := Option{ + Key: key, + } + // https://gorm.io/docs/update.html#Save-All-Fields + DB.FirstOrCreate(&option, Option{Key: key}) + option.Value = value + // Save is a combination function. + // If save value does not contain primary key, it will execute Create, + // otherwise it will execute Update (with all fields). + DB.Save(&option) + // Update OptionMap + return updateOptionMap(key, value) +} + +func updateOptionMap(key string, value string) (err error) { + common.OptionMapRWMutex.Lock() + defer common.OptionMapRWMutex.Unlock() + common.OptionMap[key] = value + if strings.HasSuffix(key, "Permission") { + intValue, _ := strconv.Atoi(value) + switch key { + case "FileUploadPermission": + common.FileUploadPermission = intValue + case "FileDownloadPermission": + common.FileDownloadPermission = intValue + case "ImageUploadPermission": + common.ImageUploadPermission = intValue + case "ImageDownloadPermission": + common.ImageDownloadPermission = intValue + } + } + if strings.HasSuffix(key, "Enabled") { + boolValue := value == "true" + switch key { + case "PasswordRegisterEnabled": + common.PasswordRegisterEnabled = boolValue + case "PasswordLoginEnabled": + common.PasswordLoginEnabled = boolValue + case "EmailVerificationEnabled": + common.EmailVerificationEnabled = boolValue + case "GitHubOAuthEnabled": + common.GitHubOAuthEnabled = boolValue + case "WeChatAuthEnabled": + common.WeChatAuthEnabled = boolValue + case "TurnstileCheckEnabled": + common.TurnstileCheckEnabled = boolValue + case "RegisterEnabled": + common.RegisterEnabled = boolValue + case "EmailDomainRestrictionEnabled": + common.EmailDomainRestrictionEnabled = boolValue + case "AutomaticDisableChannelEnabled": + common.AutomaticDisableChannelEnabled = boolValue + case "ApproximateTokenEnabled": + common.ApproximateTokenEnabled = boolValue + case "LogConsumeEnabled": + common.LogConsumeEnabled = boolValue + case "DisplayInCurrencyEnabled": + common.DisplayInCurrencyEnabled = boolValue + case "DisplayTokenStatEnabled": + common.DisplayTokenStatEnabled = boolValue + } + } + switch key { + case "EmailDomainWhitelist": + common.EmailDomainWhitelist = strings.Split(value, ",") + case "SMTPServer": + common.SMTPServer = value + case "SMTPPort": + intValue, _ := strconv.Atoi(value) + common.SMTPPort = intValue + case "SMTPAccount": + common.SMTPAccount = value + case "SMTPFrom": + common.SMTPFrom = value + case "SMTPToken": + common.SMTPToken = value + case "ServerAddress": + common.ServerAddress = value + case "GitHubClientId": + common.GitHubClientId = value + case "GitHubClientSecret": + common.GitHubClientSecret = value + case "Footer": + common.Footer = value + case "SystemName": + common.SystemName = value + case "Logo": + common.Logo = value + case "WeChatServerAddress": + common.WeChatServerAddress = value + case "WeChatServerToken": + common.WeChatServerToken = value + case "WeChatAccountQRCodeImageURL": + common.WeChatAccountQRCodeImageURL = value + case "TurnstileSiteKey": + common.TurnstileSiteKey = value + case "TurnstileSecretKey": + common.TurnstileSecretKey = value + case "QuotaForNewUser": + common.QuotaForNewUser, _ = strconv.Atoi(value) + case "QuotaForInviter": + common.QuotaForInviter, _ = strconv.Atoi(value) + case "QuotaForInvitee": + common.QuotaForInvitee, _ = strconv.Atoi(value) + case "QuotaRemindThreshold": + common.QuotaRemindThreshold, _ = strconv.Atoi(value) + case "PreConsumedQuota": + common.PreConsumedQuota, _ = strconv.Atoi(value) + case "RetryTimes": + common.RetryTimes, _ = strconv.Atoi(value) + case "ModelRatio": + err = common.UpdateModelRatioByJSONString(value) + case "GroupRatio": + err = common.UpdateGroupRatioByJSONString(value) + case "TopUpLink": + common.TopUpLink = value + case "ChatLink": + common.ChatLink = value + case "ChannelDisableThreshold": + common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) + case "QuotaPerUnit": + common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) + } + return err +} diff --git a/model/redemption.go b/model/redemption.go new file mode 100644 index 0000000000000000000000000000000000000000..fafb21454ecac8c1b759b5bcf157be95b31f4977 --- /dev/null +++ b/model/redemption.go @@ -0,0 +1,111 @@ +package model + +import ( + "errors" + "fmt" + "gorm.io/gorm" + "one-api/common" +) + +type Redemption struct { + Id int `json:"id"` + UserId int `json:"user_id"` + Key string `json:"key" gorm:"type:char(32);uniqueIndex"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index"` + Quota int `json:"quota" gorm:"default:100"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"` + Count int `json:"count" gorm:"-:all"` // only for api request +} + +func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) { + var redemptions []*Redemption + var err error + err = DB.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error + return redemptions, err +} + +func SearchRedemptions(keyword string) (redemptions []*Redemption, err error) { + err = DB.Where("id = ? or name LIKE ?", keyword, keyword+"%").Find(&redemptions).Error + return redemptions, err +} + +func GetRedemptionById(id int) (*Redemption, error) { + if id == 0 { + return nil, errors.New("id 为空!") + } + redemption := Redemption{Id: id} + var err error = nil + err = DB.First(&redemption, "id = ?", id).Error + return &redemption, err +} + +func Redeem(key string, userId int) (quota int, err error) { + if key == "" { + return 0, errors.New("未提供兑换码") + } + if userId == 0 { + return 0, errors.New("无效的 user id") + } + redemption := &Redemption{} + + err = DB.Transaction(func(tx *gorm.DB) error { + err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error + if err != nil { + return errors.New("无效的兑换码") + } + if redemption.Status != common.RedemptionCodeStatusEnabled { + return errors.New("该兑换码已被使用") + } + err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error + if err != nil { + return err + } + redemption.RedeemedTime = common.GetTimestamp() + redemption.Status = common.RedemptionCodeStatusUsed + err = tx.Save(redemption).Error + return err + }) + if err != nil { + return 0, errors.New("兑换失败," + err.Error()) + } + RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) + return redemption.Quota, nil +} + +func (redemption *Redemption) Insert() error { + var err error + err = DB.Create(redemption).Error + return err +} + +func (redemption *Redemption) SelectUpdate() error { + // This can update zero values + return DB.Model(redemption).Select("redeemed_time", "status").Updates(redemption).Error +} + +// Update Make sure your token's fields is completed, because this will update non-zero values +func (redemption *Redemption) Update() error { + var err error + err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time").Updates(redemption).Error + return err +} + +func (redemption *Redemption) Delete() error { + var err error + err = DB.Delete(redemption).Error + return err +} + +func DeleteRedemptionById(id int) (err error) { + if id == 0 { + return errors.New("id 为空!") + } + redemption := Redemption{Id: id} + err = DB.Where(redemption).First(&redemption).Error + if err != nil { + return err + } + return redemption.Delete() +} diff --git a/model/token.go b/model/token.go new file mode 100644 index 0000000000000000000000000000000000000000..7cd226c641271b218ef4ec07e4058be3ae644206 --- /dev/null +++ b/model/token.go @@ -0,0 +1,227 @@ +package model + +import ( + "errors" + "fmt" + "gorm.io/gorm" + "one-api/common" +) + +type Token struct { + Id int `json:"id"` + UserId int `json:"user_id"` + Key string `json:"key" gorm:"type:char(48);uniqueIndex"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index" ` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + AccessedTime int64 `json:"accessed_time" gorm:"bigint"` + ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired + RemainQuota int `json:"remain_quota" gorm:"default:0"` + UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` + UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota +} + +func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { + var tokens []*Token + var err error + err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error + return tokens, err +} + +func SearchUserTokens(userId int, keyword string) (tokens []*Token, err error) { + err = DB.Where("user_id = ?", userId).Where("name LIKE ?", keyword+"%").Find(&tokens).Error + return tokens, err +} + +func ValidateUserToken(key string) (token *Token, err error) { + if key == "" { + return nil, errors.New("未提供令牌") + } + token, err = CacheGetTokenByKey(key) + if err == nil { + if token.Status != common.TokenStatusEnabled { + return nil, errors.New("该令牌状态不可用") + } + if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { + token.Status = common.TokenStatusExpired + err := token.SelectUpdate() + if err != nil { + common.SysError("failed to update token status" + err.Error()) + } + return nil, errors.New("该令牌已过期") + } + if !token.UnlimitedQuota && token.RemainQuota <= 0 { + token.Status = common.TokenStatusExhausted + err := token.SelectUpdate() + if err != nil { + common.SysError("failed to update token status" + err.Error()) + } + return nil, errors.New("该令牌额度已用尽") + } + go func() { + token.AccessedTime = common.GetTimestamp() + err := token.SelectUpdate() + if err != nil { + common.SysError("failed to update token" + err.Error()) + } + }() + return token, nil + } + return nil, errors.New("无效的令牌") +} + +func GetTokenByIds(id int, userId int) (*Token, error) { + if id == 0 || userId == 0 { + return nil, errors.New("id 或 userId 为空!") + } + token := Token{Id: id, UserId: userId} + var err error = nil + err = DB.First(&token, "id = ? and user_id = ?", id, userId).Error + return &token, err +} + +func GetTokenById(id int) (*Token, error) { + if id == 0 { + return nil, errors.New("id 为空!") + } + token := Token{Id: id} + var err error = nil + err = DB.First(&token, "id = ?", id).Error + return &token, err +} + +func (token *Token) Insert() error { + var err error + err = DB.Create(token).Error + return err +} + +// Update Make sure your token's fields is completed, because this will update non-zero values +func (token *Token) Update() error { + var err error + err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error + return err +} + +func (token *Token) SelectUpdate() error { + // This can update zero values + return DB.Model(token).Select("accessed_time", "status").Updates(token).Error +} + +func (token *Token) Delete() error { + var err error + err = DB.Delete(token).Error + return err +} + +func DeleteTokenById(id int, userId int) (err error) { + // Why we need userId here? In case user want to delete other's token. + if id == 0 || userId == 0 { + return errors.New("id 或 userId 为空!") + } + token := Token{Id: id, UserId: userId} + err = DB.Where(token).First(&token).Error + if err != nil { + return err + } + return token.Delete() +} + +func IncreaseTokenQuota(id int, quota int) (err error) { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + err = DB.Model(&Token{}).Where("id = ?", id).Updates( + map[string]interface{}{ + "remain_quota": gorm.Expr("remain_quota + ?", quota), + "used_quota": gorm.Expr("used_quota - ?", quota), + }, + ).Error + return err +} + +func DecreaseTokenQuota(id int, quota int) (err error) { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + err = DB.Model(&Token{}).Where("id = ?", id).Updates( + map[string]interface{}{ + "remain_quota": gorm.Expr("remain_quota - ?", quota), + "used_quota": gorm.Expr("used_quota + ?", quota), + }, + ).Error + return err +} + +func PreConsumeTokenQuota(tokenId int, quota int) (err error) { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + token, err := GetTokenById(tokenId) + if err != nil { + return err + } + if !token.UnlimitedQuota && token.RemainQuota < quota { + return errors.New("令牌额度不足") + } + userQuota, err := GetUserQuota(token.UserId) + if err != nil { + return err + } + if userQuota < quota { + return errors.New("用户额度不足") + } + quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold + noMoreQuota := userQuota-quota <= 0 + if quotaTooLow || noMoreQuota { + go func() { + email, err := GetUserEmail(token.UserId) + if err != nil { + common.SysError("failed to fetch user email: " + err.Error()) + } + prompt := "您的额度即将用尽" + if noMoreQuota { + prompt = "您的额度已用尽" + } + if email != "" { + topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) + err = common.SendEmail(prompt, email, + fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink)) + if err != nil { + common.SysError("failed to send email" + err.Error()) + } + } + }() + } + if !token.UnlimitedQuota { + err = DecreaseTokenQuota(tokenId, quota) + if err != nil { + return err + } + } + err = DecreaseUserQuota(token.UserId, quota) + return err +} + +func PostConsumeTokenQuota(tokenId int, quota int) (err error) { + token, err := GetTokenById(tokenId) + if quota > 0 { + err = DecreaseUserQuota(token.UserId, quota) + } else { + err = IncreaseUserQuota(token.UserId, -quota) + } + if err != nil { + return err + } + if !token.UnlimitedQuota { + if quota > 0 { + err = DecreaseTokenQuota(tokenId, quota) + } else { + err = IncreaseTokenQuota(tokenId, -quota) + } + if err != nil { + return err + } + } + return nil +} diff --git a/model/user.go b/model/user.go new file mode 100644 index 0000000000000000000000000000000000000000..7c77184044a9505f317eb1d63856dc25c37402ca --- /dev/null +++ b/model/user.go @@ -0,0 +1,310 @@ +package model + +import ( + "errors" + "fmt" + "gorm.io/gorm" + "one-api/common" + "strings" +) + +// User if you add sensitive fields, don't forget to clean them in setupLogin function. +// Otherwise, the sensitive information will be saved on local storage in plain text! +type User struct { + Id int `json:"id"` + Username string `json:"username" gorm:"unique;index" validate:"max=12"` + Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"` + DisplayName string `json:"display_name" gorm:"index" validate:"max=20"` + Role int `json:"role" gorm:"type:int;default:1"` // admin, common + Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled + Email string `json:"email" gorm:"index" validate:"max=50"` + GitHubId string `json:"github_id" gorm:"column:github_id;index"` + WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` + VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! + AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management + Quota int `json:"quota" gorm:"type:int;default:0"` + UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota + RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number + Group string `json:"group" gorm:"type:varchar(32);default:'default'"` + AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` + InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"` +} + +func GetMaxUserId() int { + var user User + DB.Last(&user) + return user.Id +} + +func GetAllUsers(startIdx int, num int) (users []*User, err error) { + err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error + return users, err +} + +func SearchUsers(keyword string) (users []*User, err error) { + err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error + return users, err +} + +func GetUserById(id int, selectAll bool) (*User, error) { + if id == 0 { + return nil, errors.New("id 为空!") + } + user := User{Id: id} + var err error = nil + if selectAll { + err = DB.First(&user, "id = ?", id).Error + } else { + err = DB.Omit("password").First(&user, "id = ?", id).Error + } + return &user, err +} + +func GetUserIdByAffCode(affCode string) (int, error) { + if affCode == "" { + return 0, errors.New("affCode 为空!") + } + var user User + err := DB.Select("id").First(&user, "aff_code = ?", affCode).Error + return user.Id, err +} + +func DeleteUserById(id int) (err error) { + if id == 0 { + return errors.New("id 为空!") + } + user := User{Id: id} + return user.Delete() +} + +func (user *User) Insert(inviterId int) error { + var err error + if user.Password != "" { + user.Password, err = common.Password2Hash(user.Password) + if err != nil { + return err + } + } + user.Quota = common.QuotaForNewUser + user.AccessToken = common.GetUUID() + user.AffCode = common.GetRandomString(4) + result := DB.Create(user) + if result.Error != nil { + return result.Error + } + if common.QuotaForNewUser > 0 { + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) + } + if inviterId != 0 { + if common.QuotaForInvitee > 0 { + _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee) + RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) + } + if common.QuotaForInviter > 0 { + _ = IncreaseUserQuota(inviterId, common.QuotaForInviter) + RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter))) + } + } + return nil +} + +func (user *User) Update(updatePassword bool) error { + var err error + if updatePassword { + user.Password, err = common.Password2Hash(user.Password) + if err != nil { + return err + } + } + err = DB.Model(user).Updates(user).Error + return err +} + +func (user *User) Delete() error { + if user.Id == 0 { + return errors.New("id 为空!") + } + err := DB.Delete(user).Error + return err +} + +// ValidateAndFill check password & user status +func (user *User) ValidateAndFill() (err error) { + // When querying with struct, GORM will only query with non-zero fields, + // that means if your field’s value is 0, '', false or other zero values, + // it won’t be used to build query conditions + password := user.Password + if user.Username == "" || password == "" { + return errors.New("用户名或密码为空") + } + DB.Where(User{Username: user.Username}).First(user) + okay := common.ValidatePasswordAndHash(password, user.Password) + if !okay || user.Status != common.UserStatusEnabled { + return errors.New("用户名或密码错误,或用户已被封禁") + } + return nil +} + +func (user *User) FillUserById() error { + if user.Id == 0 { + return errors.New("id 为空!") + } + DB.Where(User{Id: user.Id}).First(user) + return nil +} + +func (user *User) FillUserByEmail() error { + if user.Email == "" { + return errors.New("email 为空!") + } + DB.Where(User{Email: user.Email}).First(user) + return nil +} + +func (user *User) FillUserByGitHubId() error { + if user.GitHubId == "" { + return errors.New("GitHub id 为空!") + } + DB.Where(User{GitHubId: user.GitHubId}).First(user) + return nil +} + +func (user *User) FillUserByWeChatId() error { + if user.WeChatId == "" { + return errors.New("WeChat id 为空!") + } + DB.Where(User{WeChatId: user.WeChatId}).First(user) + return nil +} + +func (user *User) FillUserByUsername() error { + if user.Username == "" { + return errors.New("username 为空!") + } + DB.Where(User{Username: user.Username}).First(user) + return nil +} + +func IsEmailAlreadyTaken(email string) bool { + return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1 +} + +func IsWeChatIdAlreadyTaken(wechatId string) bool { + return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1 +} + +func IsGitHubIdAlreadyTaken(githubId string) bool { + return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 +} + +func IsUsernameAlreadyTaken(username string) bool { + return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 +} + +func ResetUserPasswordByEmail(email string, password string) error { + if email == "" || password == "" { + return errors.New("邮箱地址或密码为空!") + } + hashedPassword, err := common.Password2Hash(password) + if err != nil { + return err + } + err = DB.Model(&User{}).Where("email = ?", email).Update("password", hashedPassword).Error + return err +} + +func IsAdmin(userId int) bool { + if userId == 0 { + return false + } + var user User + err := DB.Where("id = ?", userId).Select("role").Find(&user).Error + if err != nil { + common.SysError("no such user " + err.Error()) + return false + } + return user.Role >= common.RoleAdminUser +} + +func IsUserEnabled(userId int) bool { + if userId == 0 { + return false + } + var user User + err := DB.Where("id = ?", userId).Select("status").Find(&user).Error + if err != nil { + common.SysError("no such user " + err.Error()) + return false + } + return user.Status == common.UserStatusEnabled +} + +func ValidateAccessToken(token string) (user *User) { + if token == "" { + return nil + } + token = strings.Replace(token, "Bearer ", "", 1) + user = &User{} + if DB.Where("access_token = ?", token).First(user).RowsAffected == 1 { + return user + } + return nil +} + +func GetUserQuota(id int) (quota int, err error) { + err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error + return quota, err +} + +func GetUserUsedQuota(id int) (quota int, err error) { + err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error + return quota, err +} + +func GetUserEmail(id int) (email string, err error) { + err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error + return email, err +} + +func GetUserGroup(id int) (group string, err error) { + err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error + return group, err +} + +func IncreaseUserQuota(id int, quota int) (err error) { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error + return err +} + +func DecreaseUserQuota(id int, quota int) (err error) { + if quota < 0 { + return errors.New("quota 不能为负数!") + } + err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error + return err +} + +func GetRootUserEmail() (email string) { + DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) + return email +} + +func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { + err := DB.Model(&User{}).Where("id = ?", id).Updates( + map[string]interface{}{ + "used_quota": gorm.Expr("used_quota + ?", quota), + "request_count": gorm.Expr("request_count + ?", 1), + }, + ).Error + if err != nil { + common.SysError("failed to update user used quota and request count: " + err.Error()) + } +} + +func GetUsernameById(id int) (username string) { + DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username) + return username +} diff --git a/one-api.service b/one-api.service new file mode 100644 index 0000000000000000000000000000000000000000..17e236bc5e176343a44729c01d01d2110bee69fa --- /dev/null +++ b/one-api.service @@ -0,0 +1,18 @@ +# File path: /etc/systemd/system/one-api.service +# sudo systemctl daemon-reload +# sudo systemctl start one-api +# sudo systemctl enable one-api +# sudo systemctl status one-api +[Unit] +Description=One API Service +After=network.target + +[Service] +User=ubuntu # 注意修改用户名 +WorkingDirectory=/path/to/one-api # 注意修改路径 +ExecStart=/path/to/one-api/one-api --port 3000 --log-dir /path/to/one-api/logs # 注意修改路径和端口号 +Restart=always +RestartSec=5 + +[Install] +WantedBy=multi-user.target diff --git a/router/api-router.go b/router/api-router.go new file mode 100644 index 0000000000000000000000000000000000000000..cc330d7e9edf2f7b08d4683bfb87bcf7d9af9509 --- /dev/null +++ b/router/api-router.go @@ -0,0 +1,111 @@ +package router + +import ( + "one-api/controller" + "one-api/middleware" + + "github.com/gin-contrib/gzip" + "github.com/gin-gonic/gin" +) + +func SetApiRouter(router *gin.Engine) { + apiRouter := router.Group("/api") + apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) + apiRouter.Use(middleware.GlobalAPIRateLimit()) + { + apiRouter.GET("/status", controller.GetStatus) + apiRouter.GET("/notice", controller.GetNotice) + apiRouter.GET("/about", controller.GetAbout) + apiRouter.GET("/home_page_content", controller.GetHomePageContent) + apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) + apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) + apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) + apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) + apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) + apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) + apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) + + userRoute := apiRouter.Group("/user") + { + userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register) + userRoute.POST("/login", middleware.CriticalRateLimit(), controller.Login) + userRoute.GET("/logout", controller.Logout) + + selfRoute := userRoute.Group("/") + selfRoute.Use(middleware.UserAuth()) + { + selfRoute.GET("/self", controller.GetSelf) + selfRoute.PUT("/self", controller.UpdateSelf) + selfRoute.DELETE("/self", controller.DeleteSelf) + selfRoute.GET("/token", controller.GenerateAccessToken) + selfRoute.GET("/aff", controller.GetAffCode) + selfRoute.POST("/topup", controller.TopUp) + } + + adminRoute := userRoute.Group("/") + adminRoute.Use(middleware.AdminAuth()) + { + adminRoute.GET("/", controller.GetAllUsers) + adminRoute.GET("/search", controller.SearchUsers) + adminRoute.GET("/:id", controller.GetUser) + adminRoute.POST("/", controller.CreateUser) + adminRoute.POST("/manage", controller.ManageUser) + adminRoute.PUT("/", controller.UpdateUser) + adminRoute.DELETE("/:id", controller.DeleteUser) + } + } + optionRoute := apiRouter.Group("/option") + optionRoute.Use(middleware.RootAuth()) + { + optionRoute.GET("/", controller.GetOptions) + optionRoute.PUT("/", controller.UpdateOption) + } + channelRoute := apiRouter.Group("/channel") + channelRoute.Use(middleware.AdminAuth()) + { + channelRoute.GET("/", controller.GetAllChannels) + channelRoute.GET("/search", controller.SearchChannels) + channelRoute.GET("/models", controller.ListModels) + channelRoute.GET("/:id", controller.GetChannel) + channelRoute.GET("/test", controller.TestAllChannels) + channelRoute.GET("/test/:id", controller.TestChannel) + channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) + channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) + channelRoute.POST("/", controller.AddChannel) + channelRoute.PUT("/", controller.UpdateChannel) + channelRoute.DELETE("/:id", controller.DeleteChannel) + } + tokenRoute := apiRouter.Group("/token") + tokenRoute.Use(middleware.UserAuth()) + { + tokenRoute.GET("/", controller.GetAllTokens) + tokenRoute.GET("/search", controller.SearchTokens) + tokenRoute.GET("/:id", controller.GetToken) + tokenRoute.POST("/", controller.AddToken) + tokenRoute.PUT("/", controller.UpdateToken) + tokenRoute.DELETE("/:id", controller.DeleteToken) + } + redemptionRoute := apiRouter.Group("/redemption") + redemptionRoute.Use(middleware.AdminAuth()) + { + redemptionRoute.GET("/", controller.GetAllRedemptions) + redemptionRoute.GET("/search", controller.SearchRedemptions) + redemptionRoute.GET("/:id", controller.GetRedemption) + redemptionRoute.POST("/", controller.AddRedemption) + redemptionRoute.PUT("/", controller.UpdateRedemption) + redemptionRoute.DELETE("/:id", controller.DeleteRedemption) + } + logRoute := apiRouter.Group("/log") + logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs) + logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat) + logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat) + logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) + logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs) + logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs) + groupRoute := apiRouter.Group("/group") + groupRoute.Use(middleware.AdminAuth()) + { + groupRoute.GET("/", controller.GetGroups) + } + } +} diff --git a/router/dashboard.go b/router/dashboard.go new file mode 100644 index 0000000000000000000000000000000000000000..39ed1f93e53bae12e039d0b3490abaa87c8c228f --- /dev/null +++ b/router/dashboard.go @@ -0,0 +1,21 @@ +package router + +import ( + "github.com/gin-contrib/gzip" + "github.com/gin-gonic/gin" + "one-api/controller" + "one-api/middleware" +) + +func SetDashboardRouter(router *gin.Engine) { + apiRouter := router.Group("/") + apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) + apiRouter.Use(middleware.GlobalAPIRateLimit()) + apiRouter.Use(middleware.TokenAuth()) + { + apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription) + apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription) + apiRouter.GET("/dashboard/billing/usage", controller.GetUsage) + apiRouter.GET("/v1/dashboard/billing/usage", controller.GetUsage) + } +} diff --git a/router/main.go b/router/main.go new file mode 100644 index 0000000000000000000000000000000000000000..b8ac40555febfd0ae2c31585a01797b44218a5f2 --- /dev/null +++ b/router/main.go @@ -0,0 +1,30 @@ +package router + +import ( + "embed" + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "os" + "strings" +) + +func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { + SetApiRouter(router) + SetDashboardRouter(router) + SetRelayRouter(router) + frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") + if common.IsMasterNode && frontendBaseUrl != "" { + frontendBaseUrl = "" + common.SysLog("FRONTEND_BASE_URL is ignored on master node") + } + if frontendBaseUrl == "" { + SetWebRouter(router, buildFS, indexPage) + } else { + frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/") + router.NoRoute(func(c *gin.Context) { + c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI)) + }) + } +} diff --git a/router/relay-router.go b/router/relay-router.go new file mode 100644 index 0000000000000000000000000000000000000000..a76e42cf69d206023a61af1f37b0deb3917d0e39 --- /dev/null +++ b/router/relay-router.go @@ -0,0 +1,44 @@ +package router + +import ( + "one-api/controller" + "one-api/middleware" + + "github.com/gin-gonic/gin" +) + +func SetRelayRouter(router *gin.Engine) { + // https://platform.openai.com/docs/api-reference/introduction + modelsRouter := router.Group("/v1/models") + modelsRouter.Use(middleware.TokenAuth()) + { + modelsRouter.GET("", controller.ListModels) + modelsRouter.GET("/:model", controller.RetrieveModel) + } + relayV1Router := router.Group("/v1") + relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) + { + relayV1Router.POST("/completions", controller.Relay) + relayV1Router.POST("/chat/completions", controller.Relay) + relayV1Router.POST("/edits", controller.Relay) + relayV1Router.POST("/images/generations", controller.Relay) + relayV1Router.POST("/images/edits", controller.RelayNotImplemented) + relayV1Router.POST("/images/variations", controller.RelayNotImplemented) + relayV1Router.POST("/embeddings", controller.Relay) + relayV1Router.POST("/engines/:model/embeddings", controller.Relay) + relayV1Router.POST("/audio/transcriptions", controller.Relay) + relayV1Router.POST("/audio/translations", controller.Relay) + relayV1Router.GET("/files", controller.RelayNotImplemented) + relayV1Router.POST("/files", controller.RelayNotImplemented) + relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) + relayV1Router.GET("/files/:id", controller.RelayNotImplemented) + relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented) + relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented) + relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented) + relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented) + relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) + relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) + relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) + relayV1Router.POST("/moderations", controller.Relay) + } +} diff --git a/router/web-router.go b/router/web-router.go new file mode 100644 index 0000000000000000000000000000000000000000..8f9c18a2ed6d9f7c8aeabe1eb7b6a7025d8794f1 --- /dev/null +++ b/router/web-router.go @@ -0,0 +1,28 @@ +package router + +import ( + "embed" + "github.com/gin-contrib/gzip" + "github.com/gin-contrib/static" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" + "one-api/controller" + "one-api/middleware" + "strings" +) + +func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { + router.Use(gzip.Gzip(gzip.DefaultCompression)) + router.Use(middleware.GlobalWebRateLimit()) + router.Use(middleware.Cache()) + router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build"))) + router.NoRoute(func(c *gin.Context) { + if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") { + controller.RelayNotFound(c) + return + } + c.Header("Cache-Control", "no-cache") + c.Data(http.StatusOK, "text/html; charset=utf-8", indexPage) + }) +} diff --git a/web/.gitignore b/web/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2b5bba767be29f53c9efa0b00b8d7f61059bd5d2 --- /dev/null +++ b/web/.gitignore @@ -0,0 +1,26 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +# dependencies +/node_modules +/.pnp +.pnp.js + +# testing +/coverage + +# production +/build + +# misc +.DS_Store +.env.local +.env.development.local +.env.test.local +.env.production.local + +npm-debug.log* +yarn-debug.log* +yarn-error.log* +.idea +package-lock.json +yarn.lock \ No newline at end of file diff --git a/web/README.md b/web/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1b1031a3a2a50bce6205a1758b8000cb3f1cf693 --- /dev/null +++ b/web/README.md @@ -0,0 +1,21 @@ +# React Template + +## Basic Usages + +```shell +# Runs the app in the development mode +npm start + +# Builds the app for production to the `build` folder +npm run build +``` + +If you want to change the default server, please set `REACT_APP_SERVER` environment variables before build, +for example: `REACT_APP_SERVER=http://your.domain.com`. + +Before you start editing, make sure your `Actions on Save` options have `Optimize imports` & `Run Prettier` enabled. + +## Reference + +1. https://github.com/OIerDb-ng/OIerDb +2. https://github.com/cornflourblue/react-hooks-redux-registration-login-example \ No newline at end of file diff --git a/web/package.json b/web/package.json new file mode 100644 index 0000000000000000000000000000000000000000..a2bf3054761c72b529ba763aac8192396fb89c3b --- /dev/null +++ b/web/package.json @@ -0,0 +1,51 @@ +{ + "name": "react-template", + "version": "0.1.0", + "private": true, + "dependencies": { + "axios": "^0.27.2", + "history": "^5.3.0", + "marked": "^4.1.1", + "react": "^18.2.0", + "react-dom": "^18.2.0", + "react-dropzone": "^14.2.3", + "react-router-dom": "^6.3.0", + "react-scripts": "5.0.1", + "react-toastify": "^9.0.8", + "react-turnstile": "^1.0.5", + "semantic-ui-css": "^2.5.0", + "semantic-ui-react": "^2.1.3" + }, + "scripts": { + "start": "react-scripts start", + "build": "react-scripts build", + "test": "react-scripts test", + "eject": "react-scripts eject" + }, + "eslintConfig": { + "extends": [ + "react-app", + "react-app/jest" + ] + }, + "browserslist": { + "production": [ + ">0.2%", + "not dead", + "not op_mini all" + ], + "development": [ + "last 1 chrome version", + "last 1 firefox version", + "last 1 safari version" + ] + }, + "devDependencies": { + "prettier": "^2.7.1" + }, + "prettier": { + "singleQuote": true, + "jsxSingleQuote": true + }, + "proxy": "http://localhost:3000" +} diff --git a/web/public/favicon.ico b/web/public/favicon.ico new file mode 100644 index 0000000000000000000000000000000000000000..c2c8de0c5435fe2ffd94ef6da10fa2662cd9ea17 Binary files /dev/null and b/web/public/favicon.ico differ diff --git a/web/public/index.html b/web/public/index.html new file mode 100644 index 0000000000000000000000000000000000000000..b8e324d214ec1ab124a57102ab956ff171263b0c --- /dev/null +++ b/web/public/index.html @@ -0,0 +1,18 @@ + + + + + + + + + One API + + + +
+ + diff --git a/web/public/logo.png b/web/public/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..0f237a226583e08f89f14a15d86aa330a11151ae Binary files /dev/null and b/web/public/logo.png differ diff --git a/web/public/robots.txt b/web/public/robots.txt new file mode 100644 index 0000000000000000000000000000000000000000..e9e57dc4d41b9b46e05112e9f45b7ea6ac0ba15e --- /dev/null +++ b/web/public/robots.txt @@ -0,0 +1,3 @@ +# https://www.robotstxt.org/robotstxt.html +User-agent: * +Disallow: diff --git a/web/src/App.js b/web/src/App.js new file mode 100644 index 0000000000000000000000000000000000000000..c967ce2c10861170cf98d90ee482386dd2b65ff9 --- /dev/null +++ b/web/src/App.js @@ -0,0 +1,291 @@ +import React, { lazy, Suspense, useContext, useEffect } from 'react'; +import { Route, Routes } from 'react-router-dom'; +import Loading from './components/Loading'; +import User from './pages/User'; +import { PrivateRoute } from './components/PrivateRoute'; +import RegisterForm from './components/RegisterForm'; +import LoginForm from './components/LoginForm'; +import NotFound from './pages/NotFound'; +import Setting from './pages/Setting'; +import EditUser from './pages/User/EditUser'; +import AddUser from './pages/User/AddUser'; +import { API, getLogo, getSystemName, showError, showNotice } from './helpers'; +import PasswordResetForm from './components/PasswordResetForm'; +import GitHubOAuth from './components/GitHubOAuth'; +import PasswordResetConfirm from './components/PasswordResetConfirm'; +import { UserContext } from './context/User'; +import { StatusContext } from './context/Status'; +import Channel from './pages/Channel'; +import Token from './pages/Token'; +import EditToken from './pages/Token/EditToken'; +import EditChannel from './pages/Channel/EditChannel'; +import Redemption from './pages/Redemption'; +import EditRedemption from './pages/Redemption/EditRedemption'; +import TopUp from './pages/TopUp'; +import Log from './pages/Log'; +import Chat from './pages/Chat'; + +const Home = lazy(() => import('./pages/Home')); +const About = lazy(() => import('./pages/About')); + +function App() { + const [userState, userDispatch] = useContext(UserContext); + const [statusState, statusDispatch] = useContext(StatusContext); + + const loadUser = () => { + let user = localStorage.getItem('user'); + if (user) { + let data = JSON.parse(user); + userDispatch({ type: 'login', payload: data }); + } + }; + const loadStatus = async () => { + const res = await API.get('/api/status'); + const { success, data } = res.data; + if (success) { + localStorage.setItem('status', JSON.stringify(data)); + statusDispatch({ type: 'set', payload: data }); + localStorage.setItem('system_name', data.system_name); + localStorage.setItem('logo', data.logo); + localStorage.setItem('footer_html', data.footer_html); + localStorage.setItem('quota_per_unit', data.quota_per_unit); + localStorage.setItem('display_in_currency', data.display_in_currency); + if (data.chat_link) { + localStorage.setItem('chat_link', data.chat_link); + } else { + localStorage.removeItem('chat_link'); + } + if ( + data.version !== process.env.REACT_APP_VERSION && + data.version !== 'v0.0.0' && + process.env.REACT_APP_VERSION !== '' + ) { + showNotice( + `新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面` + ); + } + } else { + showError('无法正常连接至服务器!'); + } + }; + + useEffect(() => { + loadUser(); + loadStatus().then(); + let systemName = getSystemName(); + if (systemName) { + document.title = systemName; + } + let logo = getLogo(); + if (logo) { + let linkElement = document.querySelector("link[rel~='icon']"); + if (linkElement) { + linkElement.href = logo; + } + } + }, []); + + return ( + + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + + }> + + + + } + /> + + }> + + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + ); +} + +export default App; diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js new file mode 100644 index 0000000000000000000000000000000000000000..5eb397832dcd0a1ac60e65eba4d60ccf76fb85f8 --- /dev/null +++ b/web/src/components/ChannelsTable.js @@ -0,0 +1,472 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; +import { Link } from 'react-router-dom'; +import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; + +import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; +import { renderGroup, renderNumber } from '../helpers/render'; + +function renderTimestamp(timestamp) { + return ( + <> + {timestamp2string(timestamp)} + + ); +} + +let type2label = undefined; + +function renderType(type) { + if (!type2label) { + type2label = new Map; + for (let i = 0; i < CHANNEL_OPTIONS.length; i++) { + type2label[CHANNEL_OPTIONS[i].value] = CHANNEL_OPTIONS[i]; + } + type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; + } + return ; +} + +function renderBalance(type, balance) { + switch (type) { + case 1: // OpenAI + return ${balance.toFixed(2)}; + case 4: // CloseAI + return ¥{balance.toFixed(2)}; + case 8: // 自定义 + return ${balance.toFixed(2)}; + case 5: // OpenAI-SB + return ¥{(balance / 10000).toFixed(2)}; + case 10: // AI Proxy + return {renderNumber(balance)}; + case 12: // API2GPT + return ¥{balance.toFixed(2)}; + case 13: // AIGC2D + return {renderNumber(balance)}; + default: + return 不支持; + } +} + +const ChannelsTable = () => { + const [channels, setChannels] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + const [updatingBalance, setUpdatingBalance] = useState(false); + + const loadChannels = async (startIdx) => { + const res = await API.get(`/api/channel/?p=${startIdx}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setChannels(data); + } else { + let newChannels = [...channels]; + newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); + setChannels(newChannels); + } + } else { + showError(message); + } + setLoading(false); + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(channels.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + await loadChannels(activePage - 1); + } + setActivePage(activePage); + })(); + }; + + const refresh = async () => { + setLoading(true); + await loadChannels(activePage - 1); + }; + + useEffect(() => { + loadChannels(0) + .then() + .catch((reason) => { + showError(reason); + }); + }, []); + + const manageChannel = async (id, action, idx) => { + let data = { id }; + let res; + switch (action) { + case 'delete': + res = await API.delete(`/api/channel/${id}/`); + break; + case 'enable': + data.status = 1; + res = await API.put('/api/channel/', data); + break; + case 'disable': + data.status = 2; + res = await API.put('/api/channel/', data); + break; + } + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let channel = res.data.data; + let newChannels = [...channels]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + if (action === 'delete') { + newChannels[realIdx].deleted = true; + } else { + newChannels[realIdx].status = channel.status; + } + setChannels(newChannels); + } else { + showError(message); + } + }; + + const renderStatus = (status) => { + switch (status) { + case 1: + return ; + case 2: + return ( + + ); + default: + return ( + + ); + } + }; + + const renderResponseTime = (responseTime) => { + let time = responseTime / 1000; + time = time.toFixed(2) + ' 秒'; + if (responseTime === 0) { + return ; + } else if (responseTime <= 1000) { + return ; + } else if (responseTime <= 3000) { + return ; + } else if (responseTime <= 5000) { + return ; + } else { + return ; + } + }; + + const searchChannels = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. + await loadChannels(0); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/channel/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setChannels(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const testChannel = async (id, name, idx) => { + const res = await API.get(`/api/channel/test/${id}/`); + const { success, message, time } = res.data; + if (success) { + let newChannels = [...channels]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + newChannels[realIdx].response_time = time * 1000; + newChannels[realIdx].test_time = Date.now() / 1000; + setChannels(newChannels); + showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); + } else { + showError(message); + } + }; + + const testAllChannels = async () => { + const res = await API.get(`/api/channel/test`); + const { success, message } = res.data; + if (success) { + showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。'); + } else { + showError(message); + } + }; + + const updateChannelBalance = async (id, name, idx) => { + const res = await API.get(`/api/channel/update_balance/${id}/`); + const { success, message, balance } = res.data; + if (success) { + let newChannels = [...channels]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + newChannels[realIdx].balance = balance; + newChannels[realIdx].balance_updated_time = Date.now() / 1000; + setChannels(newChannels); + showInfo(`通道 ${name} 余额更新成功!`); + } else { + showError(message); + } + }; + + const updateAllChannelsBalance = async () => { + setUpdatingBalance(true); + const res = await API.get(`/api/channel/update_balance`); + const { success, message } = res.data; + if (success) { + showInfo('已更新完毕所有已启用通道余额!'); + } else { + showError(message); + } + setUpdatingBalance(false); + }; + + const handleKeywordChange = async (e, { value }) => { + setSearchKeyword(value.trim()); + }; + + const sortChannel = (key) => { + if (channels.length === 0) return; + setLoading(true); + let sortedChannels = [...channels]; + if (typeof sortedChannels[0][key] === 'string') { + sortedChannels.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + } else { + sortedChannels.sort((a, b) => { + if (a[key] === b[key]) return 0; + if (a[key] > b[key]) return -1; + if (a[key] < b[key]) return 1; + }); + } + if (sortedChannels[0].id === channels[0].id) { + sortedChannels.reverse(); + } + setChannels(sortedChannels); + setLoading(false); + }; + + return ( + <> +
+ + + + + + + { + sortChannel('id'); + }} + > + ID + + { + sortChannel('name'); + }} + > + 名称 + + { + sortChannel('group'); + }} + > + 分组 + + { + sortChannel('type'); + }} + > + 类型 + + { + sortChannel('status'); + }} + > + 状态 + + { + sortChannel('response_time'); + }} + > + 响应时间 + + { + sortChannel('balance'); + }} + > + 余额 + + 操作 + + + + + {channels + .slice( + (activePage - 1) * ITEMS_PER_PAGE, + activePage * ITEMS_PER_PAGE + ) + .map((channel, idx) => { + if (channel.deleted) return <>; + return ( + + {channel.id} + {channel.name ? channel.name : '无'} + {renderGroup(channel.group)} + {renderType(channel.type)} + {renderStatus(channel.status)} + + + + + { + updateChannelBalance(channel.id, channel.name, idx); + }} style={{ cursor: 'pointer' }}> + {renderBalance(channel.type, channel.balance)} + } + content='点击更新' + basic + /> + + +
+ + {/* {*/} + {/* updateChannelBalance(channel.id, channel.name, idx);*/} + {/* }}*/} + {/*>*/} + {/* 更新余额*/} + {/**/} + + 删除 + + } + on='click' + flowing + hoverable + > + + + + +
+
+
+ ); + })} +
+ + + + + + + + + + + + +
+ + ); +}; + +export default ChannelsTable; diff --git a/web/src/components/Footer.js b/web/src/components/Footer.js new file mode 100644 index 0000000000000000000000000000000000000000..334ee379cf5f80124abbf4226c71472160ba1323 --- /dev/null +++ b/web/src/components/Footer.js @@ -0,0 +1,61 @@ +import React, { useEffect, useState } from 'react'; + +import { Container, Segment } from 'semantic-ui-react'; +import { getFooterHTML, getSystemName } from '../helpers'; + +const Footer = () => { + const systemName = getSystemName(); + const [footer, setFooter] = useState(getFooterHTML()); + let remainCheckTimes = 5; + + const loadFooter = () => { + let footer_html = localStorage.getItem('footer_html'); + if (footer_html) { + setFooter(footer_html); + } + }; + + useEffect(() => { + const timer = setInterval(() => { + if (remainCheckTimes <= 0) { + clearInterval(timer); + return; + } + remainCheckTimes--; + loadFooter(); + }, 200); + return () => clearTimeout(timer); + }, []); + + return ( + + + {footer ? ( +
+ ) : ( + + )} +
+
+ ); +}; + +export default Footer; diff --git a/web/src/components/GitHubOAuth.js b/web/src/components/GitHubOAuth.js new file mode 100644 index 0000000000000000000000000000000000000000..147d4d30cbda98efd63e4e52f33793b6ab9f3755 --- /dev/null +++ b/web/src/components/GitHubOAuth.js @@ -0,0 +1,57 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { Dimmer, Loader, Segment } from 'semantic-ui-react'; +import { useNavigate, useSearchParams } from 'react-router-dom'; +import { API, showError, showSuccess } from '../helpers'; +import { UserContext } from '../context/User'; + +const GitHubOAuth = () => { + const [searchParams, setSearchParams] = useSearchParams(); + + const [userState, userDispatch] = useContext(UserContext); + const [prompt, setPrompt] = useState('处理中...'); + const [processing, setProcessing] = useState(true); + + let navigate = useNavigate(); + + const sendCode = async (code, count) => { + const res = await API.get(`/api/oauth/github?code=${code}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/setting'); + } else { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } + } else { + showError(message); + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + navigate('/setting'); // in case this is failed to bind GitHub + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, count * 2000)); + await sendCode(code, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + sendCode(code, 0).then(); + }, []); + + return ( + + + {prompt} + + + ); +}; + +export default GitHubOAuth; diff --git a/web/src/components/Header.js b/web/src/components/Header.js new file mode 100644 index 0000000000000000000000000000000000000000..21ebcab66479ad3dcd6ca366e64cfe4aac211dec --- /dev/null +++ b/web/src/components/Header.js @@ -0,0 +1,223 @@ +import React, { useContext, useState } from 'react'; +import { Link, useNavigate } from 'react-router-dom'; +import { UserContext } from '../context/User'; + +import { Button, Container, Dropdown, Icon, Menu, Segment } from 'semantic-ui-react'; +import { API, getLogo, getSystemName, isAdmin, isMobile, showSuccess } from '../helpers'; +import '../index.css'; + +// Header Buttons +let headerButtons = [ + { + name: '首页', + to: '/', + icon: 'home' + }, + { + name: '渠道', + to: '/channel', + icon: 'sitemap', + admin: true + }, + { + name: '令牌', + to: '/token', + icon: 'key' + }, + { + name: '兑换', + to: '/redemption', + icon: 'dollar sign', + admin: true + }, + { + name: '充值', + to: '/topup', + icon: 'cart' + }, + { + name: '用户', + to: '/user', + icon: 'user', + admin: true + }, + { + name: '日志', + to: '/log', + icon: 'book' + }, + { + name: '设置', + to: '/setting', + icon: 'setting' + }, + { + name: '关于', + to: '/about', + icon: 'info circle' + } +]; + +if (localStorage.getItem('chat_link')) { + headerButtons.splice(1, 0, { + name: '聊天', + to: '/chat', + icon: 'comments' + }); +} + +const Header = () => { + const [userState, userDispatch] = useContext(UserContext); + let navigate = useNavigate(); + + const [showSidebar, setShowSidebar] = useState(false); + const systemName = getSystemName(); + const logo = getLogo(); + + async function logout() { + setShowSidebar(false); + await API.get('/api/user/logout'); + showSuccess('注销成功!'); + userDispatch({ type: 'logout' }); + localStorage.removeItem('user'); + navigate('/login'); + } + + const toggleSidebar = () => { + setShowSidebar(!showSidebar); + }; + + const renderButtons = (isMobile) => { + return headerButtons.map((button) => { + if (button.admin && !isAdmin()) return <>; + if (isMobile) { + return ( + { + navigate(button.to); + setShowSidebar(false); + }} + > + {button.name} + + ); + } + return ( + + + {button.name} + + ); + }); + }; + + if (isMobile()) { + return ( + <> + + + + logo +
+ {systemName} +
+
+ + + + + +
+
+ {showSidebar ? ( + + + {renderButtons(true)} + + {userState.user ? ( + + ) : ( + <> + + + + )} + + + + ) : ( + <> + )} + + ); + } + + return ( + <> + + + + logo +
+ {systemName} +
+
+ {renderButtons(false)} + + {userState.user ? ( + + + 注销 + + + ) : ( + + )} + +
+
+ + ); +}; + +export default Header; diff --git a/web/src/components/Loading.js b/web/src/components/Loading.js new file mode 100644 index 0000000000000000000000000000000000000000..1210a56f3145441b13e3c3a2c2f6eab682c2d76c --- /dev/null +++ b/web/src/components/Loading.js @@ -0,0 +1,14 @@ +import React from 'react'; +import { Segment, Dimmer, Loader } from 'semantic-ui-react'; + +const Loading = ({ prompt: name = 'page' }) => { + return ( + + + 加载{name}中... + + + ); +}; + +export default Loading; diff --git a/web/src/components/LoginForm.js b/web/src/components/LoginForm.js new file mode 100644 index 0000000000000000000000000000000000000000..110dad4687e6124d699125c29b98087b63c3a236 --- /dev/null +++ b/web/src/components/LoginForm.js @@ -0,0 +1,192 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react'; +import { Link, useNavigate, useSearchParams } from 'react-router-dom'; +import { UserContext } from '../context/User'; +import { API, getLogo, showError, showSuccess } from '../helpers'; + +const LoginForm = () => { + const [inputs, setInputs] = useState({ + username: '', + password: '', + wechat_verification_code: '' + }); + const [searchParams, setSearchParams] = useSearchParams(); + const [submitted, setSubmitted] = useState(false); + const { username, password } = inputs; + const [userState, userDispatch] = useContext(UserContext); + let navigate = useNavigate(); + const [status, setStatus] = useState({}); + const logo = getLogo(); + + useEffect(() => { + if (searchParams.get('expired')) { + showError('未登录或登录已过期,请重新登录!'); + } + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + setStatus(status); + } + }, []); + + const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); + + const onGitHubOAuthClicked = () => { + window.open( + `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` + ); + }; + + const onWeChatLoginClicked = () => { + setShowWeChatLoginModal(true); + }; + + const onSubmitWeChatVerificationCode = async () => { + const res = await API.get( + `/api/oauth/wechat?code=${inputs.wechat_verification_code}` + ); + const { success, message, data } = res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + navigate('/'); + showSuccess('登录成功!'); + setShowWeChatLoginModal(false); + } else { + showError(message); + } + }; + + function handleChange(e) { + const { name, value } = e.target; + setInputs((inputs) => ({ ...inputs, [name]: value })); + } + + async function handleSubmit(e) { + setSubmitted(true); + if (username && password) { + const res = await API.post(`/api/user/login`, { + username, + password + }); + const { success, message, data } = res.data; + if (success) { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + navigate('/'); + showSuccess('登录成功!'); + } else { + showError(message); + } + } + } + + return ( + + +
+ 用户登录 +
+
+ + + + + +
+ + 忘记密码? + + 点击重置 + + ; 没有账户? + + 点击注册 + + + {status.github_oauth || status.wechat_login ? ( + <> + Or + {status.github_oauth ? ( + + + + + +
+
+ ); +}; + +export default LoginForm; diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js new file mode 100644 index 0000000000000000000000000000000000000000..bacb7689c12df84301f9b3203e4baaff47a16828 --- /dev/null +++ b/web/src/components/LogsTable.js @@ -0,0 +1,380 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Form, Header, Label, Pagination, Segment, Select, Table } from 'semantic-ui-react'; +import { API, isAdmin, showError, timestamp2string } from '../helpers'; + +import { ITEMS_PER_PAGE } from '../constants'; +import { renderQuota } from '../helpers/render'; + +function renderTimestamp(timestamp) { + return ( + <> + {timestamp2string(timestamp)} + + ); +} + +const MODE_OPTIONS = [ + { key: 'all', text: '全部用户', value: 'all' }, + { key: 'self', text: '当前用户', value: 'self' } +]; + +const LOG_OPTIONS = [ + { key: '0', text: '全部', value: 0 }, + { key: '1', text: '充值', value: 1 }, + { key: '2', text: '消费', value: 2 }, + { key: '3', text: '管理', value: 3 }, + { key: '4', text: '系统', value: 4 } +]; + +function renderType(type) { + switch (type) { + case 1: + return ; + case 2: + return ; + case 3: + return ; + case 4: + return ; + default: + return ; + } +} + +const LogsTable = () => { + const [logs, setLogs] = useState([]); + const [showStat, setShowStat] = useState(false); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + const [logType, setLogType] = useState(0); + const isAdminUser = isAdmin(); + let now = new Date(); + const [inputs, setInputs] = useState({ + username: '', + token_name: '', + model_name: '', + start_timestamp: timestamp2string(0), + end_timestamp: timestamp2string(now.getTime() / 1000 + 3600) + }); + const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs; + + const [stat, setStat] = useState({ + quota: 0, + token: 0 + }); + + const handleInputChange = (e, { name, value }) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + const getLogSelfStat = async () => { + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + let res = await API.get(`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); + const { success, message, data } = res.data; + if (success) { + setStat(data); + } else { + showError(message); + } + }; + + const getLogStat = async () => { + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`); + const { success, message, data } = res.data; + if (success) { + setStat(data); + } else { + showError(message); + } + }; + + const handleEyeClick = async () => { + if (!showStat) { + if (isAdminUser) { + await getLogStat(); + } else { + await getLogSelfStat(); + } + } + setShowStat(!showStat); + }; + + const loadLogs = async (startIdx) => { + let url = ''; + let localStartTimestamp = Date.parse(start_timestamp) / 1000; + let localEndTimestamp = Date.parse(end_timestamp) / 1000; + if (isAdminUser) { + url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } else { + url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } + const res = await API.get(url); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setLogs(data); + } else { + let newLogs = [...logs]; + newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); + setLogs(newLogs); + } + } else { + showError(message); + } + setLoading(false); + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + await loadLogs(activePage - 1); + } + setActivePage(activePage); + })(); + }; + + const refresh = async () => { + setLoading(true); + setActivePage(1); + await loadLogs(0); + }; + + useEffect(() => { + refresh().then(); + }, [logType]); + + const searchLogs = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. + await loadLogs(0); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/log/self/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setLogs(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (e, { value }) => { + setSearchKeyword(value.trim()); + }; + + const sortLog = (key) => { + if (logs.length === 0) return; + setLoading(true); + let sortedLogs = [...logs]; + if (typeof sortedLogs[0][key] === 'string') { + sortedLogs.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + } else { + sortedLogs.sort((a, b) => { + if (a[key] === b[key]) return 0; + if (a[key] > b[key]) return -1; + if (a[key] < b[key]) return 1; + }); + } + if (sortedLogs[0].id === logs[0].id) { + sortedLogs.reverse(); + } + setLogs(sortedLogs); + setLoading(false); + }; + + return ( + <> + +
+ 使用明细(总消耗额度: + {showStat && renderQuota(stat.quota)} + {!showStat && 点击查看} + ) +
+
+ + { + isAdminUser && ( + + ) + } + + + + + 查询 + +
+ + + + { + sortLog('created_time'); + }} + width={3} + > + 时间 + + { + isAdminUser && { + sortLog('username'); + }} + width={1} + > + 用户 + + } + { + sortLog('token_name'); + }} + width={1} + > + 令牌 + + { + sortLog('type'); + }} + width={1} + > + 类型 + + { + sortLog('model_name'); + }} + width={2} + > + 模型 + + { + sortLog('prompt_tokens'); + }} + width={1} + > + 提示 + + { + sortLog('completion_tokens'); + }} + width={1} + > + 补全 + + { + sortLog('quota'); + }} + width={2} + > + 消耗额度 + + { + sortLog('content'); + }} + width={isAdminUser ? 4 : 5} + > + 详情 + + + + + + {logs + .slice( + (activePage - 1) * ITEMS_PER_PAGE, + activePage * ITEMS_PER_PAGE + ) + .map((log, idx) => { + if (log.deleted) return <>; + return ( + + {renderTimestamp(log.created_at)} + { + isAdminUser && ( + {log.username ? : ''} + ) + } + {log.token_name ? : ''} + {renderType(log.type)} + {log.model_name ? : ''} + {log.prompt_tokens ? log.prompt_tokens : ''} + {log.completion_tokens ? log.completion_tokens : ''} + {log.quota ? renderQuota(log.quota, 6) : ''} + {log.content} + + ); + })} + + + + + +
+
+ + ); +}; + +export default LogsTable; diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js new file mode 100644 index 0000000000000000000000000000000000000000..2adc7fa4342a11b8e86c700c1c9e40bfdfea3f0d --- /dev/null +++ b/web/src/components/OperationSetting.js @@ -0,0 +1,331 @@ +import React, { useEffect, useState } from 'react'; +import { Divider, Form, Grid, Header } from 'semantic-ui-react'; +import { API, showError, verifyJSON } from '../helpers'; + +const OperationSetting = () => { + let [inputs, setInputs] = useState({ + QuotaForNewUser: 0, + QuotaForInviter: 0, + QuotaForInvitee: 0, + QuotaRemindThreshold: 0, + PreConsumedQuota: 0, + ModelRatio: '', + GroupRatio: '', + TopUpLink: '', + ChatLink: '', + QuotaPerUnit: 0, + AutomaticDisableChannelEnabled: '', + ChannelDisableThreshold: 0, + LogConsumeEnabled: '', + DisplayInCurrencyEnabled: '', + DisplayTokenStatEnabled: '', + ApproximateTokenEnabled: '', + RetryTimes: 0, + }); + const [originInputs, setOriginInputs] = useState({}); + let [loading, setLoading] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key === 'ModelRatio' || item.key === 'GroupRatio') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } + newInputs[item.key] = item.value; + }); + setInputs(newInputs); + setOriginInputs(newInputs); + } else { + showError(message); + } + }; + + useEffect(() => { + getOptions().then(); + }, []); + + const updateOption = async (key, value) => { + setLoading(true); + if (key.endsWith('Enabled')) { + value = inputs[key] === 'true' ? 'false' : 'true'; + } + const res = await API.put('/api/option/', { + key, + value + }); + const { success, message } = res.data; + if (success) { + setInputs((inputs) => ({ ...inputs, [key]: value })); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, { name, value }) => { + if (name.endsWith('Enabled')) { + await updateOption(name, value); + } else { + setInputs((inputs) => ({ ...inputs, [name]: value })); + } + }; + + const submitConfig = async (group) => { + switch (group) { + case 'monitor': + if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) { + await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold); + } + if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) { + await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold); + } + break; + case 'ratio': + if (originInputs['ModelRatio'] !== inputs.ModelRatio) { + if (!verifyJSON(inputs.ModelRatio)) { + showError('模型倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('ModelRatio', inputs.ModelRatio); + } + if (originInputs['GroupRatio'] !== inputs.GroupRatio) { + if (!verifyJSON(inputs.GroupRatio)) { + showError('分组倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('GroupRatio', inputs.GroupRatio); + } + break; + case 'quota': + if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) { + await updateOption('QuotaForNewUser', inputs.QuotaForNewUser); + } + if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) { + await updateOption('QuotaForInvitee', inputs.QuotaForInvitee); + } + if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) { + await updateOption('QuotaForInviter', inputs.QuotaForInviter); + } + if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) { + await updateOption('PreConsumedQuota', inputs.PreConsumedQuota); + } + break; + case 'general': + if (originInputs['TopUpLink'] !== inputs.TopUpLink) { + await updateOption('TopUpLink', inputs.TopUpLink); + } + if (originInputs['ChatLink'] !== inputs.ChatLink) { + await updateOption('ChatLink', inputs.ChatLink); + } + if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) { + await updateOption('QuotaPerUnit', inputs.QuotaPerUnit); + } + if (originInputs['RetryTimes'] !== inputs.RetryTimes) { + await updateOption('RetryTimes', inputs.RetryTimes); + } + break; + } + }; + + return ( + + +
+
+ 通用设置 +
+ + + + + + + + + + + + + { + submitConfig('general').then(); + }}>保存通用设置 + +
+ 监控设置 +
+ + + + + + + + { + submitConfig('monitor').then(); + }}>保存监控设置 + +
+ 额度设置 +
+ + + + + + + { + submitConfig('quota').then(); + }}>保存额度设置 + +
+ 倍率设置 +
+ + + + + + + { + submitConfig('ratio').then(); + }}>保存倍率设置 + +
+
+ ); +}; + +export default OperationSetting; diff --git a/web/src/components/OtherSetting.js b/web/src/components/OtherSetting.js new file mode 100644 index 0000000000000000000000000000000000000000..526a7d86872b0b40c13dcd1417bc010e64c8feea --- /dev/null +++ b/web/src/components/OtherSetting.js @@ -0,0 +1,207 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Divider, Form, Grid, Header, Message, Modal } from 'semantic-ui-react'; +import { API, showError, showSuccess } from '../helpers'; +import { marked } from 'marked'; + +const OtherSetting = () => { + let [inputs, setInputs] = useState({ + Footer: '', + Notice: '', + About: '', + SystemName: '', + Logo: '', + HomePageContent: '' + }); + let [loading, setLoading] = useState(false); + const [showUpdateModal, setShowUpdateModal] = useState(false); + const [updateData, setUpdateData] = useState({ + tag_name: '', + content: '' + }); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key in inputs) { + newInputs[item.key] = item.value; + } + }); + setInputs(newInputs); + } else { + showError(message); + } + }; + + useEffect(() => { + getOptions().then(); + }, []); + + const updateOption = async (key, value) => { + setLoading(true); + const res = await API.put('/api/option/', { + key, + value + }); + const { success, message } = res.data; + if (success) { + setInputs((inputs) => ({ ...inputs, [key]: value })); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, { name, value }) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + const submitNotice = async () => { + await updateOption('Notice', inputs.Notice); + }; + + const submitFooter = async () => { + await updateOption('Footer', inputs.Footer); + }; + + const submitSystemName = async () => { + await updateOption('SystemName', inputs.SystemName); + }; + + const submitLogo = async () => { + await updateOption('Logo', inputs.Logo); + }; + + const submitAbout = async () => { + await updateOption('About', inputs.About); + }; + + const submitOption = async (key) => { + await updateOption(key, inputs[key]); + }; + + const openGitHubRelease = () => { + window.location = + 'https://github.com/songquanpeng/one-api/releases/latest'; + }; + + const checkUpdate = async () => { + const res = await API.get( + 'https://api.github.com/repos/songquanpeng/one-api/releases/latest' + ); + const { tag_name, body } = res.data; + if (tag_name === process.env.REACT_APP_VERSION) { + showSuccess(`已是最新版本:${tag_name}`); + } else { + setUpdateData({ + tag_name: tag_name, + content: marked.parse(body) + }); + setShowUpdateModal(true); + } + }; + + return ( + + +
+
通用设置
+ 检查更新 + + + + 保存公告 + +
个性化设置
+ + + + 设置系统名称 + + + + 设置 Logo + + + + submitOption('HomePageContent')}>保存首页内容 + + + + 保存关于 + 移除 One API 的版权标识必须首先获得授权,项目维护需要花费大量精力,如果本项目对你有意义,请主动支持本项目。 + + + + 设置页脚 + +
+ setShowUpdateModal(false)} + onOpen={() => setShowUpdateModal(true)} + open={showUpdateModal} + > + 新版本:{updateData.tag_name} + + +
+
+
+ + + + + + +
+ ); +}; + +export default PasswordResetConfirm; diff --git a/web/src/components/PasswordResetForm.js b/web/src/components/PasswordResetForm.js new file mode 100644 index 0000000000000000000000000000000000000000..f3610d3a3ceb7264b681ac129ad3f42ceb13f85d --- /dev/null +++ b/web/src/components/PasswordResetForm.js @@ -0,0 +1,102 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Form, Grid, Header, Image, Segment } from 'semantic-ui-react'; +import { API, showError, showInfo, showSuccess } from '../helpers'; +import Turnstile from 'react-turnstile'; + +const PasswordResetForm = () => { + const [inputs, setInputs] = useState({ + email: '' + }); + const { email } = inputs; + + const [loading, setLoading] = useState(false); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + const [disableButton, setDisableButton] = useState(false); + const [countdown, setCountdown] = useState(30); + + useEffect(() => { + let countdownInterval = null; + if (disableButton && countdown > 0) { + countdownInterval = setInterval(() => { + setCountdown(countdown - 1); + }, 1000); + } else if (countdown === 0) { + setDisableButton(false); + setCountdown(30); + } + return () => clearInterval(countdownInterval); + }, [disableButton, countdown]); + + function handleChange(e) { + const { name, value } = e.target; + setInputs(inputs => ({ ...inputs, [name]: value })); + } + + async function handleSubmit(e) { + setDisableButton(true); + if (!email) return; + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/reset_password?email=${email}&turnstile=${turnstileToken}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('重置邮件发送成功,请检查邮箱!'); + setInputs({ ...inputs, email: '' }); + } else { + showError(message); + } + setLoading(false); + } + + return ( + + +
+ 密码重置 +
+
+ + + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} + + +
+
+
+ ); +}; + +export default PasswordResetForm; diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js new file mode 100644 index 0000000000000000000000000000000000000000..c7a303f92f819418911b4464e77788084f0a8898 --- /dev/null +++ b/web/src/components/PersonalSetting.js @@ -0,0 +1,381 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react'; +import { Link, useNavigate } from 'react-router-dom'; +import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; +import Turnstile from 'react-turnstile'; +import { UserContext } from '../context/User'; + +const PersonalSetting = () => { + const [userState, userDispatch] = useContext(UserContext); + let navigate = useNavigate(); + + const [inputs, setInputs] = useState({ + wechat_verification_code: '', + email_verification_code: '', + email: '', + self_account_deletion_confirmation: '' + }); + const [status, setStatus] = useState({}); + const [showWeChatBindModal, setShowWeChatBindModal] = useState(false); + const [showEmailBindModal, setShowEmailBindModal] = useState(false); + const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + const [loading, setLoading] = useState(false); + const [disableButton, setDisableButton] = useState(false); + const [countdown, setCountdown] = useState(30); + const [affLink, setAffLink] = useState(""); + const [systemToken, setSystemToken] = useState(""); + + useEffect(() => { + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + setStatus(status); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + } + }, []); + + useEffect(() => { + let countdownInterval = null; + if (disableButton && countdown > 0) { + countdownInterval = setInterval(() => { + setCountdown(countdown - 1); + }, 1000); + } else if (countdown === 0) { + setDisableButton(false); + setCountdown(30); + } + return () => clearInterval(countdownInterval); // Clean up on unmount + }, [disableButton, countdown]); + + const handleInputChange = (e, { name, value }) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + const generateAccessToken = async () => { + const res = await API.get('/api/user/token'); + const { success, message, data } = res.data; + if (success) { + setSystemToken(data); + setAffLink(""); + await copy(data); + showSuccess(`令牌已重置并已复制到剪贴板`); + } else { + showError(message); + } + }; + + const getAffLink = async () => { + const res = await API.get('/api/user/aff'); + const { success, message, data } = res.data; + if (success) { + let link = `${window.location.origin}/register?aff=${data}`; + setAffLink(link); + setSystemToken(""); + await copy(link); + showSuccess(`邀请链接已复制到剪切板`); + } else { + showError(message); + } + }; + + const handleAffLinkClick = async (e) => { + e.target.select(); + await copy(e.target.value); + showSuccess(`邀请链接已复制到剪切板`); + }; + + const handleSystemTokenClick = async (e) => { + e.target.select(); + await copy(e.target.value); + showSuccess(`系统令牌已复制到剪切板`); + }; + + const deleteAccount = async () => { + if (inputs.self_account_deletion_confirmation !== userState.user.username) { + showError('请输入你的账户名以确认删除!'); + return; + } + + const res = await API.delete('/api/user/self'); + const { success, message } = res.data; + + if (success) { + showSuccess('账户已删除!'); + await API.get('/api/user/logout'); + userDispatch({ type: 'logout' }); + localStorage.removeItem('user'); + navigate('/login'); + } else { + showError(message); + } + }; + + const bindWeChat = async () => { + if (inputs.wechat_verification_code === '') return; + const res = await API.get( + `/api/oauth/wechat/bind?code=${inputs.wechat_verification_code}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('微信账户绑定成功!'); + setShowWeChatBindModal(false); + } else { + showError(message); + } + }; + + const openGitHubOAuth = () => { + window.open( + `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` + ); + }; + + const sendVerificationCode = async () => { + setDisableButton(true); + if (inputs.email === '') return; + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('验证码发送成功,请检查邮箱!'); + } else { + showError(message); + } + setLoading(false); + }; + + const bindEmail = async () => { + if (inputs.email_verification_code === '') return; + setLoading(true); + const res = await API.get( + `/api/oauth/email/bind?email=${inputs.email}&code=${inputs.email_verification_code}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('邮箱账户绑定成功!'); + setShowEmailBindModal(false); + } else { + showError(message); + } + setLoading(false); + }; + + return ( +
+
通用设置
+ + 注意,此处生成的令牌用于系统管理,而非用于请求 OpenAI 相关的服务,请知悉。 + + + + + + + {systemToken && ( + + )} + {affLink && ( + + )} + +
账号绑定
+ { + status.wechat_login && ( + + ) + } + setShowWeChatBindModal(false)} + onOpen={() => setShowWeChatBindModal(true)} + open={showWeChatBindModal} + size={'mini'} + > + + + +
+

+ 微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效) +

+
+
+ + + +
+
+
+ { + status.github_oauth && ( + + ) + } + + setShowEmailBindModal(false)} + onOpen={() => setShowEmailBindModal(true)} + open={showEmailBindModal} + size={'tiny'} + style={{ maxWidth: '450px' }} + > + 绑定邮箱地址 + + +
+ + {disableButton ? `重新发送(${countdown})` : '获取验证码'} + + } + /> + + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} +
+ +
+ +
+ +
+
+
+ setShowAccountDeleteModal(false)} + onOpen={() => setShowAccountDeleteModal(true)} + open={showAccountDeleteModal} + size={'tiny'} + style={{ maxWidth: '450px' }} + > + 危险操作 + + 您正在删除自己的帐户,将清空所有数据且不可恢复 + +
+ + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} +
+ +
+ +
+ +
+
+
+
+ ); +}; + +export default PersonalSetting; diff --git a/web/src/components/PrivateRoute.js b/web/src/components/PrivateRoute.js new file mode 100644 index 0000000000000000000000000000000000000000..f7cc7248996655b18d652fac0e9348d5c5782527 --- /dev/null +++ b/web/src/components/PrivateRoute.js @@ -0,0 +1,13 @@ +import { Navigate } from 'react-router-dom'; + +import { history } from '../helpers'; + + +function PrivateRoute({ children }) { + if (!localStorage.getItem('user')) { + return ; + } + return children; +} + +export { PrivateRoute }; \ No newline at end of file diff --git a/web/src/components/RedemptionsTable.js b/web/src/components/RedemptionsTable.js new file mode 100644 index 0000000000000000000000000000000000000000..ae8b5b03412ebb74973574db9781ca84932f2c7b --- /dev/null +++ b/web/src/components/RedemptionsTable.js @@ -0,0 +1,314 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Form, Label, Popup, Pagination, Table } from 'semantic-ui-react'; +import { Link } from 'react-router-dom'; +import { API, copy, showError, showInfo, showSuccess, showWarning, timestamp2string } from '../helpers'; + +import { ITEMS_PER_PAGE } from '../constants'; +import { renderQuota } from '../helpers/render'; + +function renderTimestamp(timestamp) { + return ( + <> + {timestamp2string(timestamp)} + + ); +} + +function renderStatus(status) { + switch (status) { + case 1: + return ; + case 2: + return ; + case 3: + return ; + default: + return ; + } +} + +const RedemptionsTable = () => { + const [redemptions, setRedemptions] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + + const loadRedemptions = async (startIdx) => { + const res = await API.get(`/api/redemption/?p=${startIdx}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setRedemptions(data); + } else { + let newRedemptions = redemptions; + newRedemptions.push(...data); + setRedemptions(newRedemptions); + } + } else { + showError(message); + } + setLoading(false); + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(redemptions.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + await loadRedemptions(activePage - 1); + } + setActivePage(activePage); + })(); + }; + + useEffect(() => { + loadRedemptions(0) + .then() + .catch((reason) => { + showError(reason); + }); + }, []); + + const manageRedemption = async (id, action, idx) => { + let data = { id }; + let res; + switch (action) { + case 'delete': + res = await API.delete(`/api/redemption/${id}/`); + break; + case 'enable': + data.status = 1; + res = await API.put('/api/redemption/?status_only=true', data); + break; + case 'disable': + data.status = 2; + res = await API.put('/api/redemption/?status_only=true', data); + break; + } + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let redemption = res.data.data; + let newRedemptions = [...redemptions]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + if (action === 'delete') { + newRedemptions[realIdx].deleted = true; + } else { + newRedemptions[realIdx].status = redemption.status; + } + setRedemptions(newRedemptions); + } else { + showError(message); + } + }; + + const searchRedemptions = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. + await loadRedemptions(0); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/redemption/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setRedemptions(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (e, { value }) => { + setSearchKeyword(value.trim()); + }; + + const sortRedemption = (key) => { + if (redemptions.length === 0) return; + setLoading(true); + let sortedRedemptions = [...redemptions]; + sortedRedemptions.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + if (sortedRedemptions[0].id === redemptions[0].id) { + sortedRedemptions.reverse(); + } + setRedemptions(sortedRedemptions); + setLoading(false); + }; + + return ( + <> +
+ + + + + + + { + sortRedemption('id'); + }} + > + ID + + { + sortRedemption('name'); + }} + > + 名称 + + { + sortRedemption('status'); + }} + > + 状态 + + { + sortRedemption('quota'); + }} + > + 额度 + + { + sortRedemption('created_time'); + }} + > + 创建时间 + + { + sortRedemption('redeemed_time'); + }} + > + 兑换时间 + + 操作 + + + + + {redemptions + .slice( + (activePage - 1) * ITEMS_PER_PAGE, + activePage * ITEMS_PER_PAGE + ) + .map((redemption, idx) => { + if (redemption.deleted) return <>; + return ( + + {redemption.id} + {redemption.name ? redemption.name : '无'} + {renderStatus(redemption.status)} + {renderQuota(redemption.quota)} + {renderTimestamp(redemption.created_time)} + {redemption.redeemed_time ? renderTimestamp(redemption.redeemed_time) : "尚未兑换"} + +
+ + + 删除 + + } + on='click' + flowing + hoverable + > + + + + +
+
+
+ ); + })} +
+ + + + + + + + + +
+ + ); +}; + +export default RedemptionsTable; diff --git a/web/src/components/RegisterForm.js b/web/src/components/RegisterForm.js new file mode 100644 index 0000000000000000000000000000000000000000..f91d6da0905d3d66831021fc0e559a66380770aa --- /dev/null +++ b/web/src/components/RegisterForm.js @@ -0,0 +1,194 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Form, Grid, Header, Image, Message, Segment } from 'semantic-ui-react'; +import { Link, useNavigate } from 'react-router-dom'; +import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; +import Turnstile from 'react-turnstile'; + +const RegisterForm = () => { + const [inputs, setInputs] = useState({ + username: '', + password: '', + password2: '', + email: '', + verification_code: '' + }); + const { username, password, password2 } = inputs; + const [showEmailVerification, setShowEmailVerification] = useState(false); + const [turnstileEnabled, setTurnstileEnabled] = useState(false); + const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); + const [turnstileToken, setTurnstileToken] = useState(''); + const [loading, setLoading] = useState(false); + const logo = getLogo(); + let affCode = new URLSearchParams(window.location.search).get('aff'); + if (affCode) { + localStorage.setItem('aff', affCode); + } + + useEffect(() => { + let status = localStorage.getItem('status'); + if (status) { + status = JSON.parse(status); + setShowEmailVerification(status.email_verification); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + } + }); + + let navigate = useNavigate(); + + function handleChange(e) { + const { name, value } = e.target; + console.log(name, value); + setInputs((inputs) => ({ ...inputs, [name]: value })); + } + + async function handleSubmit(e) { + if (password.length < 8) { + showInfo('密码长度不得小于 8 位!'); + return; + } + if (password !== password2) { + showInfo('两次输入的密码不一致'); + return; + } + if (username && password) { + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + if (!affCode) { + affCode = localStorage.getItem('aff'); + } + inputs.aff_code = affCode; + const res = await API.post( + `/api/user/register?turnstile=${turnstileToken}`, + inputs + ); + const { success, message } = res.data; + if (success) { + navigate('/login'); + showSuccess('注册成功!'); + } else { + showError(message); + } + setLoading(false); + } + } + + const sendVerificationCode = async () => { + if (inputs.email === '') return; + if (turnstileEnabled && turnstileToken === '') { + showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); + return; + } + setLoading(true); + const res = await API.get( + `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}` + ); + const { success, message } = res.data; + if (success) { + showSuccess('验证码发送成功,请检查你的邮箱!'); + } else { + showError(message); + } + setLoading(false); + }; + + return ( + + +
+ 新用户注册 +
+
+ + + + + {showEmailVerification ? ( + <> + + 获取验证码 + + } + /> + + + ) : ( + <> + )} + {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} + /> + ) : ( + <> + )} + + +
+ + 已有账户? + + 点击登录 + + +
+
+ ); +}; + +export default RegisterForm; diff --git a/web/src/components/SystemSetting.js b/web/src/components/SystemSetting.js new file mode 100644 index 0000000000000000000000000000000000000000..7b34ce5b2dd3a2eef48ea3d8c89284567ba1a936 --- /dev/null +++ b/web/src/components/SystemSetting.js @@ -0,0 +1,537 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Divider, Form, Grid, Header, Modal, Message } from 'semantic-ui-react'; +import { API, removeTrailingSlash, showError } from '../helpers'; + +const SystemSetting = () => { + let [inputs, setInputs] = useState({ + PasswordLoginEnabled: '', + PasswordRegisterEnabled: '', + EmailVerificationEnabled: '', + GitHubOAuthEnabled: '', + GitHubClientId: '', + GitHubClientSecret: '', + Notice: '', + SMTPServer: '', + SMTPPort: '', + SMTPAccount: '', + SMTPFrom: '', + SMTPToken: '', + ServerAddress: '', + Footer: '', + WeChatAuthEnabled: '', + WeChatServerAddress: '', + WeChatServerToken: '', + WeChatAccountQRCodeImageURL: '', + TurnstileCheckEnabled: '', + TurnstileSiteKey: '', + TurnstileSecretKey: '', + RegisterEnabled: '', + EmailDomainRestrictionEnabled: '', + EmailDomainWhitelist: '' + }); + const [originInputs, setOriginInputs] = useState({}); + let [loading, setLoading] = useState(false); + const [EmailDomainWhitelist, setEmailDomainWhitelist] = useState([]); + const [restrictedDomainInput, setRestrictedDomainInput] = useState(''); + const [showPasswordWarningModal, setShowPasswordWarningModal] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + newInputs[item.key] = item.value; + }); + setInputs({ + ...newInputs, + EmailDomainWhitelist: newInputs.EmailDomainWhitelist.split(',') + }); + setOriginInputs(newInputs); + + setEmailDomainWhitelist(newInputs.EmailDomainWhitelist.split(',').map((item) => { + return { key: item, text: item, value: item }; + })); + } else { + showError(message); + } + }; + + useEffect(() => { + getOptions().then(); + }, []); + + const updateOption = async (key, value) => { + setLoading(true); + switch (key) { + case 'PasswordLoginEnabled': + case 'PasswordRegisterEnabled': + case 'EmailVerificationEnabled': + case 'GitHubOAuthEnabled': + case 'WeChatAuthEnabled': + case 'TurnstileCheckEnabled': + case 'EmailDomainRestrictionEnabled': + case 'RegisterEnabled': + value = inputs[key] === 'true' ? 'false' : 'true'; + break; + default: + break; + } + const res = await API.put('/api/option/', { + key, + value + }); + const { success, message } = res.data; + if (success) { + if (key === 'EmailDomainWhitelist') { + value = value.split(','); + } + setInputs((inputs) => ({ + ...inputs, [key]: value + })); + } else { + showError(message); + } + setLoading(false); + }; + + const handleInputChange = async (e, { name, value }) => { + if (name === 'PasswordLoginEnabled' && inputs[name] === 'true') { + // block disabling password login + setShowPasswordWarningModal(true); + return; + } + if ( + name === 'Notice' || + name.startsWith('SMTP') || + name === 'ServerAddress' || + name === 'GitHubClientId' || + name === 'GitHubClientSecret' || + name === 'WeChatServerAddress' || + name === 'WeChatServerToken' || + name === 'WeChatAccountQRCodeImageURL' || + name === 'TurnstileSiteKey' || + name === 'TurnstileSecretKey' || + name === 'EmailDomainWhitelist' + ) { + setInputs((inputs) => ({ ...inputs, [name]: value })); + } else { + await updateOption(name, value); + } + }; + + const submitServerAddress = async () => { + let ServerAddress = removeTrailingSlash(inputs.ServerAddress); + await updateOption('ServerAddress', ServerAddress); + }; + + const submitSMTP = async () => { + if (originInputs['SMTPServer'] !== inputs.SMTPServer) { + await updateOption('SMTPServer', inputs.SMTPServer); + } + if (originInputs['SMTPAccount'] !== inputs.SMTPAccount) { + await updateOption('SMTPAccount', inputs.SMTPAccount); + } + if (originInputs['SMTPFrom'] !== inputs.SMTPFrom) { + await updateOption('SMTPFrom', inputs.SMTPFrom); + } + if ( + originInputs['SMTPPort'] !== inputs.SMTPPort && + inputs.SMTPPort !== '' + ) { + await updateOption('SMTPPort', inputs.SMTPPort); + } + if ( + originInputs['SMTPToken'] !== inputs.SMTPToken && + inputs.SMTPToken !== '' + ) { + await updateOption('SMTPToken', inputs.SMTPToken); + } + }; + + + const submitEmailDomainWhitelist = async () => { + if ( + originInputs['EmailDomainWhitelist'] !== inputs.EmailDomainWhitelist.join(',') && + inputs.SMTPToken !== '' + ) { + await updateOption('EmailDomainWhitelist', inputs.EmailDomainWhitelist.join(',')); + } + }; + + const submitWeChat = async () => { + if (originInputs['WeChatServerAddress'] !== inputs.WeChatServerAddress) { + await updateOption( + 'WeChatServerAddress', + removeTrailingSlash(inputs.WeChatServerAddress) + ); + } + if ( + originInputs['WeChatAccountQRCodeImageURL'] !== + inputs.WeChatAccountQRCodeImageURL + ) { + await updateOption( + 'WeChatAccountQRCodeImageURL', + inputs.WeChatAccountQRCodeImageURL + ); + } + if ( + originInputs['WeChatServerToken'] !== inputs.WeChatServerToken && + inputs.WeChatServerToken !== '' + ) { + await updateOption('WeChatServerToken', inputs.WeChatServerToken); + } + }; + + const submitGitHubOAuth = async () => { + if (originInputs['GitHubClientId'] !== inputs.GitHubClientId) { + await updateOption('GitHubClientId', inputs.GitHubClientId); + } + if ( + originInputs['GitHubClientSecret'] !== inputs.GitHubClientSecret && + inputs.GitHubClientSecret !== '' + ) { + await updateOption('GitHubClientSecret', inputs.GitHubClientSecret); + } + }; + + const submitTurnstile = async () => { + if (originInputs['TurnstileSiteKey'] !== inputs.TurnstileSiteKey) { + await updateOption('TurnstileSiteKey', inputs.TurnstileSiteKey); + } + if ( + originInputs['TurnstileSecretKey'] !== inputs.TurnstileSecretKey && + inputs.TurnstileSecretKey !== '' + ) { + await updateOption('TurnstileSecretKey', inputs.TurnstileSecretKey); + } + }; + + const submitNewRestrictedDomain = () => { + const localDomainList = inputs.EmailDomainWhitelist; + if (restrictedDomainInput !== '' && !localDomainList.includes(restrictedDomainInput)) { + setRestrictedDomainInput(''); + setInputs({ + ...inputs, + EmailDomainWhitelist: [...localDomainList, restrictedDomainInput], + }); + setEmailDomainWhitelist([...EmailDomainWhitelist, { + key: restrictedDomainInput, + text: restrictedDomainInput, + value: restrictedDomainInput, + }]); + } + } + + return ( + + +
+
通用设置
+ + + + + 更新服务器地址 + + +
配置登录注册
+ + + { + showPasswordWarningModal && + setShowPasswordWarningModal(false)} + size={'tiny'} + style={{ maxWidth: '450px' }} + > + 警告 + +

取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?

+
+ + + + +
+ } + + + + +
+ + + + + +
+ 配置邮箱域名白名单 + 用以防止恶意用户利用临时邮箱批量注册 +
+ + + + + + { + submitNewRestrictedDomain(); + }}>填入 + } + onKeyDown={(e) => { + if (e.key === 'Enter') { + submitNewRestrictedDomain(); + } + }} + autoComplete='new-password' + placeholder='输入新的允许的邮箱域名' + value={restrictedDomainInput} + onChange={(e, { value }) => { + setRestrictedDomainInput(value); + }} + /> + + 保存邮箱域名白名单设置 + +
+ 配置 SMTP + 用以支持系统的邮件发送 +
+ + + + + + + + + + 保存 SMTP 设置 + +
+ 配置 GitHub OAuth App + + 用以支持通过 GitHub 进行登录注册, + + 点击此处 + + 管理你的 GitHub OAuth App + +
+ + Homepage URL 填 {inputs.ServerAddress} + ,Authorization callback URL 填{' '} + {`${inputs.ServerAddress}/oauth/github`} + + + + + + + 保存 GitHub OAuth 设置 + + +
+ 配置 WeChat Server + + 用以支持通过微信进行登录注册, + + 点击此处 + + 了解 WeChat Server + +
+ + + + + + + 保存 WeChat Server 设置 + + +
+ 配置 Turnstile + + 用以支持用户校验, + + 点击此处 + + 管理你的 Turnstile Sites,推荐选择 Invisible Widget Type + +
+ + + + + + 保存 Turnstile 设置 + + +
+
+ ); +}; + +export default SystemSetting; diff --git a/web/src/components/TokensTable.js b/web/src/components/TokensTable.js new file mode 100644 index 0000000000000000000000000000000000000000..b45f07dfdd4f325e5cab5985ea19c6b115379171 --- /dev/null +++ b/web/src/components/TokensTable.js @@ -0,0 +1,443 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Dropdown, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; +import { Link } from 'react-router-dom'; +import { API, copy, showError, showSuccess, showWarning, timestamp2string } from '../helpers'; + +import { ITEMS_PER_PAGE } from '../constants'; +import { renderQuota } from '../helpers/render'; + +const COPY_OPTIONS = [ + { key: 'next', text: 'ChatGPT Next Web', value: 'next' }, + { key: 'ama', text: 'AMA 问天', value: 'ama' }, + { key: 'opencat', text: 'OpenCat', value: 'opencat' }, +]; + +const OPEN_LINK_OPTIONS = [ + { key: 'ama', text: 'AMA 问天', value: 'ama' }, + { key: 'opencat', text: 'OpenCat', value: 'opencat' }, +]; + +function renderTimestamp(timestamp) { + return ( + <> + {timestamp2string(timestamp)} + + ); +} + +function renderStatus(status) { + switch (status) { + case 1: + return ; + case 2: + return ; + case 3: + return ; + case 4: + return ; + default: + return ; + } +} + +const TokensTable = () => { + const [tokens, setTokens] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + const [showTopUpModal, setShowTopUpModal] = useState(false); + const [targetTokenIdx, setTargetTokenIdx] = useState(0); + + const loadTokens = async (startIdx) => { + const res = await API.get(`/api/token/?p=${startIdx}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setTokens(data); + } else { + let newTokens = [...tokens]; + newTokens.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); + setTokens(newTokens); + } + } else { + showError(message); + } + setLoading(false); + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(tokens.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + await loadTokens(activePage - 1); + } + setActivePage(activePage); + })(); + }; + + const refresh = async () => { + setLoading(true); + await loadTokens(activePage - 1); + }; + + const onCopy = async (type, key) => { + let status = localStorage.getItem('status'); + let serverAddress = ''; + if (status) { + status = JSON.parse(status); + serverAddress = status.server_address; + } + if (serverAddress === '') { + serverAddress = window.location.origin; + } + let encodedServerAddress = encodeURIComponent(serverAddress); + const nextLink = localStorage.getItem('chat_link'); + let nextUrl; + + if (nextLink) { + nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`; + } else { + nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + } + + let url; + switch (type) { + case 'ama': + url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`; + break; + case 'opencat': + url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`; + break; + case 'next': + url = nextUrl; + break; + default: + url = `sk-${key}`; + } + if (await copy(url)) { + showSuccess('已复制到剪贴板!'); + } else { + showWarning('无法复制到剪贴板,请手动复制,已将令牌填入搜索框。'); + setSearchKeyword(url); + } + }; + + const onOpenLink = async (type, key) => { + let status = localStorage.getItem('status'); + let serverAddress = ''; + if (status) { + status = JSON.parse(status); + serverAddress = status.server_address; + } + if (serverAddress === '') { + serverAddress = window.location.origin; + } + let encodedServerAddress = encodeURIComponent(serverAddress); + const chatLink = localStorage.getItem('chat_link'); + let defaultUrl; + + if (chatLink) { + defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`; + } else { + defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; + } + let url; + switch (type) { + case 'ama': + url = `ama://set-api-key?server=${encodedServerAddress}&key=sk-${key}`; + break; + + case 'opencat': + url = `opencat://team/join?domain=${encodedServerAddress}&token=sk-${key}`; + break; + + default: + url = defaultUrl; + } + + window.open(url, '_blank'); + } + + useEffect(() => { + loadTokens(0) + .then() + .catch((reason) => { + showError(reason); + }); + }, []); + + const manageToken = async (id, action, idx) => { + let data = { id }; + let res; + switch (action) { + case 'delete': + res = await API.delete(`/api/token/${id}/`); + break; + case 'enable': + data.status = 1; + res = await API.put('/api/token/?status_only=true', data); + break; + case 'disable': + data.status = 2; + res = await API.put('/api/token/?status_only=true', data); + break; + } + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let token = res.data.data; + let newTokens = [...tokens]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + if (action === 'delete') { + newTokens[realIdx].deleted = true; + } else { + newTokens[realIdx].status = token.status; + } + setTokens(newTokens); + } else { + showError(message); + } + }; + + const searchTokens = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. + await loadTokens(0); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/token/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setTokens(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (e, { value }) => { + setSearchKeyword(value.trim()); + }; + + const sortToken = (key) => { + if (tokens.length === 0) return; + setLoading(true); + let sortedTokens = [...tokens]; + sortedTokens.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + if (sortedTokens[0].id === tokens[0].id) { + sortedTokens.reverse(); + } + setTokens(sortedTokens); + setLoading(false); + }; + + return ( + <> +
+ + + + + + + { + sortToken('name'); + }} + > + 名称 + + { + sortToken('status'); + }} + > + 状态 + + { + sortToken('used_quota'); + }} + > + 已用额度 + + { + sortToken('remain_quota'); + }} + > + 剩余额度 + + { + sortToken('created_time'); + }} + > + 创建时间 + + { + sortToken('expired_time'); + }} + > + 过期时间 + + 操作 + + + + + {tokens + .slice( + (activePage - 1) * ITEMS_PER_PAGE, + activePage * ITEMS_PER_PAGE + ) + .map((token, idx) => { + if (token.deleted) return <>; + return ( + + {token.name ? token.name : '无'} + {renderStatus(token.status)} + {renderQuota(token.used_quota)} + {token.unlimited_quota ? '无限制' : renderQuota(token.remain_quota, 2)} + {renderTimestamp(token.created_time)} + {token.expired_time === -1 ? '永不过期' : renderTimestamp(token.expired_time)} + +
+ + + ({ + ...option, + onClick: async () => { + await onCopy(option.value, token.key); + } + }))} + trigger={<>} + /> + + {' '} + + + ({ + ...option, + onClick: async () => { + await onOpenLink(option.value, token.key); + } + }))} + trigger={<>} + /> + + {' '} + + 删除 + + } + on='click' + flowing + hoverable + > + + + + +
+
+
+ ); + })} +
+ + + + + + + + + + +
+ + ); +}; + +export default TokensTable; diff --git a/web/src/components/UsersTable.js b/web/src/components/UsersTable.js new file mode 100644 index 0000000000000000000000000000000000000000..f8fb0a75b329abee6c7c9ec8a52c5d2ae6c1b465 --- /dev/null +++ b/web/src/components/UsersTable.js @@ -0,0 +1,338 @@ +import React, { useEffect, useState } from 'react'; +import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react'; +import { Link } from 'react-router-dom'; +import { API, showError, showSuccess } from '../helpers'; + +import { ITEMS_PER_PAGE } from '../constants'; +import { renderGroup, renderNumber, renderQuota, renderText } from '../helpers/render'; + +function renderRole(role) { + switch (role) { + case 1: + return ; + case 10: + return ; + case 100: + return ; + default: + return ; + } +} + +const UsersTable = () => { + const [users, setUsers] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [searchKeyword, setSearchKeyword] = useState(''); + const [searching, setSearching] = useState(false); + + const loadUsers = async (startIdx) => { + const res = await API.get(`/api/user/?p=${startIdx}`); + const { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setUsers(data); + } else { + let newUsers = users; + newUsers.push(...data); + setUsers(newUsers); + } + } else { + showError(message); + } + setLoading(false); + }; + + const onPaginationChange = (e, { activePage }) => { + (async () => { + if (activePage === Math.ceil(users.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + await loadUsers(activePage - 1); + } + setActivePage(activePage); + })(); + }; + + useEffect(() => { + loadUsers(0) + .then() + .catch((reason) => { + showError(reason); + }); + }, []); + + const manageUser = (username, action, idx) => { + (async () => { + const res = await API.post('/api/user/manage', { + username, + action + }); + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + let user = res.data.data; + let newUsers = [...users]; + let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; + if (action === 'delete') { + newUsers[realIdx].deleted = true; + } else { + newUsers[realIdx].status = user.status; + newUsers[realIdx].role = user.role; + } + setUsers(newUsers); + } else { + showError(message); + } + })(); + }; + + const renderStatus = (status) => { + switch (status) { + case 1: + return ; + case 2: + return ( + + ); + default: + return ( + + ); + } + }; + + const searchUsers = async () => { + if (searchKeyword === '') { + // if keyword is blank, load files instead. + await loadUsers(0); + setActivePage(1); + return; + } + setSearching(true); + const res = await API.get(`/api/user/search?keyword=${searchKeyword}`); + const { success, message, data } = res.data; + if (success) { + setUsers(data); + setActivePage(1); + } else { + showError(message); + } + setSearching(false); + }; + + const handleKeywordChange = async (e, { value }) => { + setSearchKeyword(value.trim()); + }; + + const sortUser = (key) => { + if (users.length === 0) return; + setLoading(true); + let sortedUsers = [...users]; + sortedUsers.sort((a, b) => { + return ('' + a[key]).localeCompare(b[key]); + }); + if (sortedUsers[0].id === users[0].id) { + sortedUsers.reverse(); + } + setUsers(sortedUsers); + setLoading(false); + }; + + return ( + <> +
+ + + + + + + { + sortUser('id'); + }} + > + ID + + { + sortUser('username'); + }} + > + 用户名 + + { + sortUser('group'); + }} + > + 分组 + + { + sortUser('quota'); + }} + > + 统计信息 + + { + sortUser('role'); + }} + > + 用户角色 + + { + sortUser('status'); + }} + > + 状态 + + 操作 + + + + + {users + .slice( + (activePage - 1) * ITEMS_PER_PAGE, + activePage * ITEMS_PER_PAGE + ) + .map((user, idx) => { + if (user.deleted) return <>; + return ( + + {user.id} + + {renderText(user.username, 15)}} + hoverable + /> + + {renderGroup(user.group)} + {/**/} + {/* {user.email ? {renderText(user.email, 24)}} /> : '无'}*/} + {/**/} + + {renderQuota(user.quota)}} /> + {renderQuota(user.used_quota)}} /> + {renderNumber(user.request_count)}} /> + + {renderRole(user.role)} + {renderStatus(user.status)} + +
+ + + + 删除 + + } + on='click' + flowing + hoverable + > + + + + +
+
+
+ ); + })} +
+ + + + + + + + + +
+ + ); +}; + +export default UsersTable; diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js new file mode 100644 index 0000000000000000000000000000000000000000..b163147974655f88fe9f53933f3a44ecc8274597 --- /dev/null +++ b/web/src/constants/channel.constants.js @@ -0,0 +1,22 @@ +export const CHANNEL_OPTIONS = [ + { key: 1, text: 'OpenAI', value: 1, color: 'green' }, + { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, + { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, + { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, + { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, + { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, + { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, + { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, + { key: 19, text: '360 智脑', value: 19, color: 'blue' }, + { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, + { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, + { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, + { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, + { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' }, + { key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' }, + { key: 4, text: '代理:CloseAI', value: 4, color: 'teal' }, + { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' }, + { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' }, + { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' }, + { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' } +]; \ No newline at end of file diff --git a/web/src/constants/common.constant.js b/web/src/constants/common.constant.js new file mode 100644 index 0000000000000000000000000000000000000000..1a37d5f6f4f81a8fa35e91bda97b70928f75a413 --- /dev/null +++ b/web/src/constants/common.constant.js @@ -0,0 +1 @@ +export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend! diff --git a/web/src/constants/index.js b/web/src/constants/index.js new file mode 100644 index 0000000000000000000000000000000000000000..e83152bcfecd6c0c9ed4fd87a4e11b5c059481c5 --- /dev/null +++ b/web/src/constants/index.js @@ -0,0 +1,4 @@ +export * from './toast.constants'; +export * from './user.constants'; +export * from './common.constant'; +export * from './channel.constants'; \ No newline at end of file diff --git a/web/src/constants/toast.constants.js b/web/src/constants/toast.constants.js new file mode 100644 index 0000000000000000000000000000000000000000..50684722c52085d56cb0a7e9671517fa403d68ea --- /dev/null +++ b/web/src/constants/toast.constants.js @@ -0,0 +1,7 @@ +export const toastConstants = { + SUCCESS_TIMEOUT: 1500, + INFO_TIMEOUT: 3000, + ERROR_TIMEOUT: 5000, + WARNING_TIMEOUT: 10000, + NOTICE_TIMEOUT: 20000 +}; diff --git a/web/src/constants/user.constants.js b/web/src/constants/user.constants.js new file mode 100644 index 0000000000000000000000000000000000000000..2680d8ef1ef5fe6e4900d233d035a926527de81b --- /dev/null +++ b/web/src/constants/user.constants.js @@ -0,0 +1,19 @@ +export const userConstants = { + REGISTER_REQUEST: 'USERS_REGISTER_REQUEST', + REGISTER_SUCCESS: 'USERS_REGISTER_SUCCESS', + REGISTER_FAILURE: 'USERS_REGISTER_FAILURE', + + LOGIN_REQUEST: 'USERS_LOGIN_REQUEST', + LOGIN_SUCCESS: 'USERS_LOGIN_SUCCESS', + LOGIN_FAILURE: 'USERS_LOGIN_FAILURE', + + LOGOUT: 'USERS_LOGOUT', + + GETALL_REQUEST: 'USERS_GETALL_REQUEST', + GETALL_SUCCESS: 'USERS_GETALL_SUCCESS', + GETALL_FAILURE: 'USERS_GETALL_FAILURE', + + DELETE_REQUEST: 'USERS_DELETE_REQUEST', + DELETE_SUCCESS: 'USERS_DELETE_SUCCESS', + DELETE_FAILURE: 'USERS_DELETE_FAILURE' +}; diff --git a/web/src/context/Status/index.js b/web/src/context/Status/index.js new file mode 100644 index 0000000000000000000000000000000000000000..71f0682b63fe43ab696a4a1f90b654b139de1a18 --- /dev/null +++ b/web/src/context/Status/index.js @@ -0,0 +1,19 @@ +// contexts/User/index.jsx + +import React from 'react'; +import { initialState, reducer } from './reducer'; + +export const StatusContext = React.createContext({ + state: initialState, + dispatch: () => null, +}); + +export const StatusProvider = ({ children }) => { + const [state, dispatch] = React.useReducer(reducer, initialState); + + return ( + + {children} + + ); +}; \ No newline at end of file diff --git a/web/src/context/Status/reducer.js b/web/src/context/Status/reducer.js new file mode 100644 index 0000000000000000000000000000000000000000..ec9ac6ae667a90353cd088cb2df04b0413a16239 --- /dev/null +++ b/web/src/context/Status/reducer.js @@ -0,0 +1,20 @@ +export const reducer = (state, action) => { + switch (action.type) { + case 'set': + return { + ...state, + status: action.payload, + }; + case 'unset': + return { + ...state, + status: undefined, + }; + default: + return state; + } +}; + +export const initialState = { + status: undefined, +}; diff --git a/web/src/context/User/index.js b/web/src/context/User/index.js new file mode 100644 index 0000000000000000000000000000000000000000..c6671591982466262d252e9c5c5ddc14591cb365 --- /dev/null +++ b/web/src/context/User/index.js @@ -0,0 +1,19 @@ +// contexts/User/index.jsx + +import React from "react" +import { reducer, initialState } from "./reducer" + +export const UserContext = React.createContext({ + state: initialState, + dispatch: () => null +}) + +export const UserProvider = ({ children }) => { + const [state, dispatch] = React.useReducer(reducer, initialState) + + return ( + + { children } + + ) +} \ No newline at end of file diff --git a/web/src/context/User/reducer.js b/web/src/context/User/reducer.js new file mode 100644 index 0000000000000000000000000000000000000000..9ed1d809af6e5a33687026d75ad6788cea8a02c4 --- /dev/null +++ b/web/src/context/User/reducer.js @@ -0,0 +1,21 @@ +export const reducer = (state, action) => { + switch (action.type) { + case 'login': + return { + ...state, + user: action.payload + }; + case 'logout': + return { + ...state, + user: undefined + }; + + default: + return state; + } +}; + +export const initialState = { + user: undefined +}; \ No newline at end of file diff --git a/web/src/helpers/api.js b/web/src/helpers/api.js new file mode 100644 index 0000000000000000000000000000000000000000..35fdb1e95f7f623c27e452a6946c7224501842da --- /dev/null +++ b/web/src/helpers/api.js @@ -0,0 +1,13 @@ +import { showError } from './utils'; +import axios from 'axios'; + +export const API = axios.create({ + baseURL: process.env.REACT_APP_SERVER ? process.env.REACT_APP_SERVER : '', +}); + +API.interceptors.response.use( + (response) => response, + (error) => { + showError(error); + } +); diff --git a/web/src/helpers/auth-header.js b/web/src/helpers/auth-header.js new file mode 100644 index 0000000000000000000000000000000000000000..a8fe5f5a7fa19e9213c208a7f7077b16401fcfd2 --- /dev/null +++ b/web/src/helpers/auth-header.js @@ -0,0 +1,10 @@ +export function authHeader() { + // return authorization header with jwt token + let user = JSON.parse(localStorage.getItem('user')); + + if (user && user.token) { + return { 'Authorization': 'Bearer ' + user.token }; + } else { + return {}; + } +} \ No newline at end of file diff --git a/web/src/helpers/history.js b/web/src/helpers/history.js new file mode 100644 index 0000000000000000000000000000000000000000..629039e5affe8208e1154b67708101da8910e0f9 --- /dev/null +++ b/web/src/helpers/history.js @@ -0,0 +1,3 @@ +import { createBrowserHistory } from 'history'; + +export const history = createBrowserHistory(); \ No newline at end of file diff --git a/web/src/helpers/index.js b/web/src/helpers/index.js new file mode 100644 index 0000000000000000000000000000000000000000..505a8cf9166a78c457f51b079ae06994775e6fe9 --- /dev/null +++ b/web/src/helpers/index.js @@ -0,0 +1,4 @@ +export * from './history'; +export * from './auth-header'; +export * from './utils'; +export * from './api'; \ No newline at end of file diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js new file mode 100644 index 0000000000000000000000000000000000000000..a9c81cc11b5889e87d9e04d03b25b5cf3e5f1ac1 --- /dev/null +++ b/web/src/helpers/render.js @@ -0,0 +1,58 @@ +import { Label } from 'semantic-ui-react'; + +export function renderText(text, limit) { + if (text.length > limit) { + return text.slice(0, limit - 3) + '...'; + } + return text; +} + +export function renderGroup(group) { + if (group === '') { + return ; + } + let groups = group.split(','); + groups.sort(); + return <> + {groups.map((group) => { + if (group === 'vip' || group === 'pro') { + return ; + } else if (group === 'svip' || group === 'premium') { + return ; + } + return ; + })} + ; +} + +export function renderNumber(num) { + if (num >= 1000000000) { + return (num / 1000000000).toFixed(1) + 'B'; + } else if (num >= 1000000) { + return (num / 1000000).toFixed(1) + 'M'; + } else if (num >= 10000) { + return (num / 1000).toFixed(1) + 'k'; + } else { + return num; + } +} + +export function renderQuota(quota, digits = 2) { + let quotaPerUnit = localStorage.getItem('quota_per_unit'); + let displayInCurrency = localStorage.getItem('display_in_currency'); + quotaPerUnit = parseFloat(quotaPerUnit); + displayInCurrency = displayInCurrency === 'true'; + if (displayInCurrency) { + return '$' + (quota / quotaPerUnit).toFixed(digits); + } + return renderNumber(quota); +} + +export function renderQuotaWithPrompt(quota, digits) { + let displayInCurrency = localStorage.getItem('display_in_currency'); + displayInCurrency = displayInCurrency === 'true'; + if (displayInCurrency) { + return `(等价金额:${renderQuota(quota, digits)})`; + } + return ''; +} \ No newline at end of file diff --git a/web/src/helpers/utils.js b/web/src/helpers/utils.js new file mode 100644 index 0000000000000000000000000000000000000000..3871a43ec0be203b61c666ae1580169bbe392212 --- /dev/null +++ b/web/src/helpers/utils.js @@ -0,0 +1,189 @@ +import { toast } from 'react-toastify'; +import { toastConstants } from '../constants'; +import React from 'react'; + +const HTMLToastContent = ({ htmlContent }) => { + return
; +}; +export default HTMLToastContent; +export function isAdmin() { + let user = localStorage.getItem('user'); + if (!user) return false; + user = JSON.parse(user); + return user.role >= 10; +} + +export function isRoot() { + let user = localStorage.getItem('user'); + if (!user) return false; + user = JSON.parse(user); + return user.role >= 100; +} + +export function getSystemName() { + let system_name = localStorage.getItem('system_name'); + if (!system_name) return 'One API'; + return system_name; +} + +export function getLogo() { + let logo = localStorage.getItem('logo'); + if (!logo) return '/logo.png'; + return logo +} + +export function getFooterHTML() { + return localStorage.getItem('footer_html'); +} + +export async function copy(text) { + let okay = true; + try { + await navigator.clipboard.writeText(text); + } catch (e) { + okay = false; + console.error(e); + } + return okay; +} + +export function isMobile() { + return window.innerWidth <= 600; +} + +let showErrorOptions = { autoClose: toastConstants.ERROR_TIMEOUT }; +let showWarningOptions = { autoClose: toastConstants.WARNING_TIMEOUT }; +let showSuccessOptions = { autoClose: toastConstants.SUCCESS_TIMEOUT }; +let showInfoOptions = { autoClose: toastConstants.INFO_TIMEOUT }; +let showNoticeOptions = { autoClose: false }; + +if (isMobile()) { + showErrorOptions.position = 'top-center'; + // showErrorOptions.transition = 'flip'; + + showSuccessOptions.position = 'top-center'; + // showSuccessOptions.transition = 'flip'; + + showInfoOptions.position = 'top-center'; + // showInfoOptions.transition = 'flip'; + + showNoticeOptions.position = 'top-center'; + // showNoticeOptions.transition = 'flip'; +} + +export function showError(error) { + console.error(error); + if (error.message) { + if (error.name === 'AxiosError') { + switch (error.response.status) { + case 401: + // toast.error('错误:未登录或登录已过期,请重新登录!', showErrorOptions); + window.location.href = '/login?expired=true'; + break; + case 429: + toast.error('错误:请求次数过多,请稍后再试!', showErrorOptions); + break; + case 500: + toast.error('错误:服务器内部错误,请联系管理员!', showErrorOptions); + break; + case 405: + toast.info('本站仅作演示之用,无服务端!'); + break; + default: + toast.error('错误:' + error.message, showErrorOptions); + } + return; + } + toast.error('错误:' + error.message, showErrorOptions); + } else { + toast.error('错误:' + error, showErrorOptions); + } +} + +export function showWarning(message) { + toast.warn(message, showWarningOptions); +} + +export function showSuccess(message) { + toast.success(message, showSuccessOptions); +} + +export function showInfo(message) { + toast.info(message, showInfoOptions); +} + +export function showNotice(message, isHTML = false) { + if (isHTML) { + toast(, showNoticeOptions); + } else { + toast.info(message, showNoticeOptions); + } +} + +export function openPage(url) { + window.open(url); +} + +export function removeTrailingSlash(url) { + if (url.endsWith('/')) { + return url.slice(0, -1); + } else { + return url; + } +} + +export function timestamp2string(timestamp) { + let date = new Date(timestamp * 1000); + let year = date.getFullYear().toString(); + let month = (date.getMonth() + 1).toString(); + let day = date.getDate().toString(); + let hour = date.getHours().toString(); + let minute = date.getMinutes().toString(); + let second = date.getSeconds().toString(); + if (month.length === 1) { + month = '0' + month; + } + if (day.length === 1) { + day = '0' + day; + } + if (hour.length === 1) { + hour = '0' + hour; + } + if (minute.length === 1) { + minute = '0' + minute; + } + if (second.length === 1) { + second = '0' + second; + } + return ( + year + + '-' + + month + + '-' + + day + + ' ' + + hour + + ':' + + minute + + ':' + + second + ); +} + +export function downloadTextAsFile(text, filename) { + let blob = new Blob([text], { type: 'text/plain;charset=utf-8' }); + let url = URL.createObjectURL(blob); + let a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); +} + +export const verifyJSON = (str) => { + try { + JSON.parse(str); + } catch (e) { + return false; + } + return true; +}; \ No newline at end of file diff --git a/web/src/index.css b/web/src/index.css new file mode 100644 index 0000000000000000000000000000000000000000..5d60e377358899f370a2f7405380ab149dbba661 --- /dev/null +++ b/web/src/index.css @@ -0,0 +1,35 @@ +body { + margin: 0; + padding-top: 55px; + overflow-y: scroll; + font-family: Lato, 'Helvetica Neue', Arial, Helvetica, "Microsoft YaHei", sans-serif; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; + scrollbar-width: none; +} + +body::-webkit-scrollbar { + display: none; +} + +code { + font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace; +} + +.main-content { + padding: 4px; +} + +.small-icon .icon { + font-size: 1em !important; +} + +.custom-footer { + font-size: 1.1em; +} + +@media only screen and (max-width: 600px) { + .hide-on-mobile { + display: none !important; + } +} diff --git a/web/src/index.js b/web/src/index.js new file mode 100644 index 0000000000000000000000000000000000000000..eca5c3c02fda5a5540050dc6850602f020f7dcf0 --- /dev/null +++ b/web/src/index.js @@ -0,0 +1,31 @@ +import React from 'react'; +import ReactDOM from 'react-dom/client'; +import { BrowserRouter } from 'react-router-dom'; +import { Container } from 'semantic-ui-react'; +import App from './App'; +import Header from './components/Header'; +import Footer from './components/Footer'; +import 'semantic-ui-css/semantic.min.css'; +import './index.css'; +import { UserProvider } from './context/User'; +import { ToastContainer } from 'react-toastify'; +import 'react-toastify/dist/ReactToastify.css'; +import { StatusProvider } from './context/Status'; + +const root = ReactDOM.createRoot(document.getElementById('root')); +root.render( + + + + +
+ + + + +