| package model |
|
|
| import ( |
| "fmt" |
| "log" |
| "os" |
| "strings" |
| "sync" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/constant" |
|
|
| "github.com/glebarez/sqlite" |
| "gorm.io/driver/mysql" |
| "gorm.io/driver/postgres" |
| "gorm.io/gorm" |
| ) |
|
|
// Dialect-specific SQL fragments for the main database. All of them start
// out empty and are assigned by initCol.
var (
	commonGroupCol string // quoted form of the reserved column name "group"
	commonKeyCol   string // quoted form of the reserved column name "key"
	commonTrueVal  string // boolean true literal
	commonFalseVal string // boolean false literal
)

// Quoted column names for the log database, assigned by initCol.
var (
	logKeyCol   string
	logGroupCol string
)
|
|
| func initCol() { |
| |
| if common.UsingPostgreSQL { |
| commonGroupCol = `"group"` |
| commonKeyCol = `"key"` |
| commonTrueVal = "true" |
| commonFalseVal = "false" |
| } else { |
| commonGroupCol = "`group`" |
| commonKeyCol = "`key`" |
| commonTrueVal = "1" |
| commonFalseVal = "0" |
| } |
| if os.Getenv("LOG_SQL_DSN") != "" { |
| switch common.LogSqlType { |
| case common.DatabaseTypePostgreSQL: |
| logGroupCol = `"group"` |
| logKeyCol = `"key"` |
| default: |
| logGroupCol = commonGroupCol |
| logKeyCol = commonKeyCol |
| } |
| } else { |
| |
| if common.UsingPostgreSQL { |
| logGroupCol = `"group"` |
| logKeyCol = `"key"` |
| } else { |
| logGroupCol = commonGroupCol |
| logKeyCol = commonKeyCol |
| } |
| } |
| |
| |
| } |
|
|
| var DB *gorm.DB |
|
|
| var LOG_DB *gorm.DB |
|
|
| func createRootAccountIfNeed() error { |
| var user User |
| |
| if err := DB.First(&user).Error; err != nil { |
| common.SysLog("no user exists, create a root user for you: username is root, password is 123456") |
| hashedPassword, err := common.Password2Hash("123456") |
| if err != nil { |
| return err |
| } |
| rootUser := User{ |
| Username: "root", |
| Password: hashedPassword, |
| Role: common.RoleRootUser, |
| Status: common.UserStatusEnabled, |
| DisplayName: "Root User", |
| AccessToken: nil, |
| Quota: 100000000, |
| } |
| DB.Create(&rootUser) |
| } |
| return nil |
| } |
|
|
| func CheckSetup() { |
| setup := GetSetup() |
| if setup == nil { |
| |
| if RootUserExists() { |
| common.SysLog("system is not initialized, but root user exists") |
| |
| newSetup := Setup{ |
| Version: common.Version, |
| InitializedAt: time.Now().Unix(), |
| } |
| err := DB.Create(&newSetup).Error |
| if err != nil { |
| common.SysLog("failed to create setup record: " + err.Error()) |
| } |
| constant.Setup = true |
| } else { |
| common.SysLog("system is not initialized and no root user exists") |
| constant.Setup = false |
| } |
| } else { |
| |
| common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String()) |
| constant.Setup = true |
| } |
| } |
|
|
| func chooseDB(envName string, isLog bool) (*gorm.DB, error) { |
| defer func() { |
| initCol() |
| }() |
| dsn := os.Getenv(envName) |
| if dsn != "" { |
| if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { |
| |
| common.SysLog("using PostgreSQL as database") |
| if !isLog { |
| common.UsingPostgreSQL = true |
| } else { |
| common.LogSqlType = common.DatabaseTypePostgreSQL |
| } |
| return gorm.Open(postgres.New(postgres.Config{ |
| DSN: dsn, |
| PreferSimpleProtocol: true, |
| }), &gorm.Config{ |
| PrepareStmt: true, |
| }) |
| } |
| if strings.HasPrefix(dsn, "local") { |
| common.SysLog("SQL_DSN not set, using SQLite as database") |
| if !isLog { |
| common.UsingSQLite = true |
| } else { |
| common.LogSqlType = common.DatabaseTypeSQLite |
| } |
| return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ |
| PrepareStmt: true, |
| }) |
| } |
| |
| common.SysLog("using MySQL as database") |
| |
| if !strings.Contains(dsn, "parseTime") { |
| if strings.Contains(dsn, "?") { |
| dsn += "&parseTime=true" |
| } else { |
| dsn += "?parseTime=true" |
| } |
| } |
| if !isLog { |
| common.UsingMySQL = true |
| } else { |
| common.LogSqlType = common.DatabaseTypeMySQL |
| } |
| return gorm.Open(mysql.Open(dsn), &gorm.Config{ |
| PrepareStmt: true, |
| }) |
| } |
| |
| common.SysLog("SQL_DSN not set, using SQLite as database") |
| common.UsingSQLite = true |
| return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ |
| PrepareStmt: true, |
| }) |
| } |
|
|
| func InitDB() (err error) { |
| db, err := chooseDB("SQL_DSN", false) |
| if err == nil { |
| if common.DebugEnabled { |
| db = db.Debug() |
| } |
| DB = db |
| |
| if common.UsingMySQL { |
| if err := checkMySQLChineseSupport(DB); err != nil { |
| panic(err) |
| } |
| } |
| sqlDB, err := DB.DB() |
| if err != nil { |
| return err |
| } |
| sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100)) |
| sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000)) |
| sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60))) |
|
|
| if !common.IsMasterNode { |
| return nil |
| } |
| if common.UsingMySQL { |
| |
| } |
| common.SysLog("database migration started") |
| err = migrateDB() |
| return err |
| } else { |
| common.FatalLog(err) |
| } |
| return err |
| } |
|
|
| func InitLogDB() (err error) { |
| if os.Getenv("LOG_SQL_DSN") == "" { |
| LOG_DB = DB |
| return |
| } |
| db, err := chooseDB("LOG_SQL_DSN", true) |
| if err == nil { |
| if common.DebugEnabled { |
| db = db.Debug() |
| } |
| LOG_DB = db |
| |
| if common.LogSqlType == common.DatabaseTypeMySQL { |
| if err := checkMySQLChineseSupport(LOG_DB); err != nil { |
| panic(err) |
| } |
| } |
| sqlDB, err := LOG_DB.DB() |
| if err != nil { |
| return err |
| } |
| sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100)) |
| sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000)) |
| sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60))) |
|
|
| if !common.IsMasterNode { |
| return nil |
| } |
| common.SysLog("database migration started") |
| err = migrateLOGDB() |
| return err |
| } else { |
| common.FatalLog(err) |
| } |
| return err |
| } |
|
|
| func migrateDB() error { |
| err := DB.AutoMigrate( |
| &Channel{}, |
| &Token{}, |
| &User{}, |
| &PasskeyCredential{}, |
| &Option{}, |
| &Redemption{}, |
| &Ability{}, |
| &Log{}, |
| &Midjourney{}, |
| &TopUp{}, |
| &QuotaData{}, |
| &Task{}, |
| &Model{}, |
| &Vendor{}, |
| &PrefillGroup{}, |
| &Setup{}, |
| &TwoFA{}, |
| &TwoFABackupCode{}, |
| ) |
| if err != nil { |
| return err |
| } |
| return nil |
| } |
|
|
| func migrateDBFast() error { |
|
|
| var wg sync.WaitGroup |
|
|
| migrations := []struct { |
| model interface{} |
| name string |
| }{ |
| {&Channel{}, "Channel"}, |
| {&Token{}, "Token"}, |
| {&User{}, "User"}, |
| {&PasskeyCredential{}, "PasskeyCredential"}, |
| {&Option{}, "Option"}, |
| {&Redemption{}, "Redemption"}, |
| {&Ability{}, "Ability"}, |
| {&Log{}, "Log"}, |
| {&Midjourney{}, "Midjourney"}, |
| {&TopUp{}, "TopUp"}, |
| {&QuotaData{}, "QuotaData"}, |
| {&Task{}, "Task"}, |
| {&Model{}, "Model"}, |
| {&Vendor{}, "Vendor"}, |
| {&PrefillGroup{}, "PrefillGroup"}, |
| {&Setup{}, "Setup"}, |
| {&TwoFA{}, "TwoFA"}, |
| {&TwoFABackupCode{}, "TwoFABackupCode"}, |
| } |
| |
| errChan := make(chan error, len(migrations)) |
|
|
| for _, m := range migrations { |
| wg.Add(1) |
| go func(model interface{}, name string) { |
| defer wg.Done() |
| if err := DB.AutoMigrate(model); err != nil { |
| errChan <- fmt.Errorf("failed to migrate %s: %v", name, err) |
| } |
| }(m.model, m.name) |
| } |
|
|
| |
| wg.Wait() |
| close(errChan) |
|
|
| |
| for err := range errChan { |
| if err != nil { |
| return err |
| } |
| } |
| common.SysLog("database migrated") |
| return nil |
| } |
|
|
| func migrateLOGDB() error { |
| var err error |
| if err = LOG_DB.AutoMigrate(&Log{}); err != nil { |
| return err |
| } |
| return nil |
| } |
|
|
| func closeDB(db *gorm.DB) error { |
| sqlDB, err := db.DB() |
| if err != nil { |
| return err |
| } |
| err = sqlDB.Close() |
| return err |
| } |
|
|
| func CloseDB() error { |
| if LOG_DB != DB { |
| err := closeDB(LOG_DB) |
| if err != nil { |
| return err |
| } |
| } |
| return closeDB(DB) |
| } |
|
|
| |
| |
| |
| func checkMySQLChineseSupport(db *gorm.DB) error { |
| |
|
|
| |
| var schemaCharset, schemaCollation string |
| err := db.Raw("SELECT DEFAULT_CHARACTER_SET_NAME, DEFAULT_COLLATION_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = DATABASE()").Row().Scan(&schemaCharset, &schemaCollation) |
| if err != nil { |
| return fmt.Errorf("读取当前库默认字符集/排序规则失败 / Failed to read schema default charset/collation: %v", err) |
| } |
|
|
| toLower := func(s string) string { return strings.ToLower(s) } |
| |
| allowedCharsets := map[string]string{ |
| "utf8mb4": "utf8mb4_", |
| "utf8": "utf8_", |
| "gbk": "gbk_", |
| "big5": "big5_", |
| "gb18030": "gb18030_", |
| } |
| isChineseCapable := func(cs, cl string) bool { |
| csLower := toLower(cs) |
| clLower := toLower(cl) |
| if prefix, ok := allowedCharsets[csLower]; ok { |
| if clLower == "" { |
| return true |
| } |
| return strings.HasPrefix(clLower, prefix) |
| } |
| |
| for _, prefix := range allowedCharsets { |
| if strings.HasPrefix(clLower, prefix) { |
| return true |
| } |
| } |
| return false |
| } |
|
|
| |
| if !isChineseCapable(schemaCharset, schemaCollation) { |
| return fmt.Errorf("当前库默认字符集/排序规则不支持中文:schema(%s/%s)。请将库设置为 utf8mb4/utf8/gbk/big5/gb18030 / Schema default charset/collation is not Chinese-capable: schema(%s/%s). Please set to utf8mb4/utf8/gbk/big5/gb18030", |
| schemaCharset, schemaCollation, schemaCharset, schemaCollation) |
| } |
|
|
| |
| type tableInfo struct { |
| Name string |
| Collation *string |
| } |
| var tables []tableInfo |
| if err := db.Raw("SELECT TABLE_NAME, TABLE_COLLATION FROM information_schema.TABLES WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE'").Scan(&tables).Error; err != nil { |
| return fmt.Errorf("读取表排序规则失败 / Failed to read table collations: %v", err) |
| } |
|
|
| var badTables []string |
| for _, t := range tables { |
| |
| if t.Collation == nil || *t.Collation == "" { |
| continue |
| } |
| cl := *t.Collation |
| |
| ok := false |
| lower := strings.ToLower(cl) |
| for _, prefix := range allowedCharsets { |
| if strings.HasPrefix(lower, prefix) { |
| ok = true |
| break |
| } |
| } |
| if !ok { |
| badTables = append(badTables, fmt.Sprintf("%s(%s)", t.Name, cl)) |
| } |
| } |
|
|
| if len(badTables) > 0 { |
| |
| maxShow := 20 |
| shown := badTables |
| if len(shown) > maxShow { |
| shown = shown[:maxShow] |
| } |
| return fmt.Errorf( |
| "存在不支持中文的表,请修复其排序规则/字符集。示例(最多展示 %d 项):%v / Found tables not Chinese-capable. Please fix their collation/charset. Examples (showing up to %d): %v", |
| maxShow, shown, maxShow, shown, |
| ) |
| } |
| return nil |
| } |
|
|
| var ( |
| lastPingTime time.Time |
| pingMutex sync.Mutex |
| ) |
|
|
| func PingDB() error { |
| pingMutex.Lock() |
| defer pingMutex.Unlock() |
|
|
| if time.Since(lastPingTime) < time.Second*10 { |
| return nil |
| } |
|
|
| sqlDB, err := DB.DB() |
| if err != nil { |
| log.Printf("Error getting sql.DB from GORM: %v", err) |
| return err |
| } |
|
|
| err = sqlDB.Ping() |
| if err != nil { |
| log.Printf("Error pinging DB: %v", err) |
| return err |
| } |
|
|
| lastPingTime = time.Now() |
| common.SysLog("Database pinged successfully") |
| return nil |
| } |
|
|