Spaces:
Build error
Build error
| package model | |
| import ( | |
| "database/sql/driver" | |
| "encoding/json" | |
| "one-api/constant" | |
| commonRelay "one-api/relay/common" | |
| "time" | |
| ) | |
// TaskStatus is the lifecycle state of a task, stored in the status column.
type TaskStatus string

// TaskStatusNotStart is the initial state; InitTask assigns it to new tasks.
const TaskStatusNotStart TaskStatus = "NOT_START"

// Remaining task states.
//
// NOTE(review): unlike TaskStatusNotStart these are untyped string constants.
// A const spec that carries its own value does not inherit the type of the
// spec before it. They still assign to TaskStatus fields implicitly, and they
// also assign to plain strings. Consider typing them once callers have been
// checked.
const (
	TaskStatusSubmitted  = "SUBMITTED"
	TaskStatusQueued     = "QUEUED"
	TaskStatusInProgress = "IN_PROGRESS"
	TaskStatusFailure    = "FAILURE"
	TaskStatusSuccess    = "SUCCESS"
	TaskStatusUnknown    = "UNKNOWN"
)
// Task is the persisted record of a task submitted to an upstream platform.
// It tracks the owner, channel, charged quota, lifecycle status, progress
// and timestamps. Field order is significant for JSON output; keep it stable.
type Task struct {
	ID         int64                 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
	CreatedAt  int64                 `json:"created_at" gorm:"index"`
	UpdatedAt  int64                 `json:"updated_at"`
	TaskID     string                `json:"task_id" gorm:"type:varchar(191);index"` // Third-party ID, not always present (song ID / task ID).
	Platform   constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // Platform the task belongs to.
	UserId     int                   `json:"user_id" gorm:"index"`
	ChannelId  int                   `json:"channel_id" gorm:"index"`
	Quota      int                   `json:"quota"`
	Action     string                `json:"action" gorm:"type:varchar(40);index"` // Task type, e.g. song, lyrics, description-mode.
	Status     TaskStatus            `json:"status" gorm:"type:varchar(20);index"` // Task status; see the TaskStatus constants.
	FailReason string                `json:"fail_reason"`
	SubmitTime int64                 `json:"submit_time" gorm:"index"` // Unix seconds; InitTask sets it from time.Now().Unix().
	StartTime  int64                 `json:"start_time" gorm:"index"`
	FinishTime int64                 `json:"finish_time" gorm:"index"`
	Progress   string                `json:"progress" gorm:"type:varchar(20);index"` // Percentage string such as "0%" or "100%".
	Properties Properties            `json:"properties" gorm:"type:json"`
	Data       json.RawMessage       `json:"data" gorm:"type:json"` // Raw JSON payload; written by SetData, read by GetData.
}
// SetData stores the JSON encoding of data in t.Data.
//
// This is best effort. If data cannot be marshalled, json.Marshal returns a
// nil slice, so t.Data becomes nil instead of holding partial output.
func (t *Task) SetData(data any) {
	encoded, _ := json.Marshal(data) // error intentionally dropped; see above
	t.Data = encoded
}
// GetData decodes the JSON stored in t.Data into v.
//
// v must be a non-nil pointer such as &myStruct; anything else returns a
// *json.InvalidUnmarshalError. An empty Data returns a JSON syntax error.
func (t *Task) GetData(v any) error {
	// Pass v itself, not &v. Taking the address of the interface let
	// json.Unmarshal overwrite the local interface value whenever v was not a
	// pointer, so a mistyped call decoded nothing and still returned nil.
	return json.Unmarshal(t.Data, v)
}
| type Properties struct { | |
| Input string `json:"input"` | |
| } | |
| func (m *Properties) Scan(val interface{}) error { | |
| bytesValue, _ := val.([]byte) | |
| return json.Unmarshal(bytesValue, m) | |
| } | |
| func (m Properties) Value() (driver.Value, error) { | |
| return json.Marshal(m) | |
| } | |
// SyncTaskQueryParams holds all search conditions accepted by the task query
// helpers; add more fields here as new filters are needed. A zero value
// (empty string, 0, or empty slice) means the corresponding filter is not
// applied.
type SyncTaskQueryParams struct {
	Platform       constant.TaskPlatform
	ChannelID      string // compared against channel_id
	TaskID         string
	UserID         string // single-user filter, compared against user_id
	Action         string
	Status         string
	StartTimestamp int64 // inclusive lower bound on submit_time
	EndTimestamp   int64 // inclusive upper bound on submit_time
	UserIDs        []int // when non-empty, restricts user_id to this set
}
// InitTask builds a new, unsaved Task for platform, owned by the user and
// channel recorded in relayInfo. The task starts in TaskStatusNotStart with
// progress "0%", and its submit time is set to the current Unix second.
// relayInfo must be non-nil.
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task {
	return &Task{
		Platform:   platform,
		UserId:     relayInfo.UserId,
		ChannelId:  relayInfo.ChannelId,
		Status:     TaskStatusNotStart,
		Progress:   "0%",
		SubmitTime: time.Now().Unix(),
	}
}
| func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task { | |
| var tasks []*Task | |
| var err error | |
| // 初始化查询构建器 | |
| query := DB.Where("user_id = ?", userId) | |
| if queryParams.TaskID != "" { | |
| query = query.Where("task_id = ?", queryParams.TaskID) | |
| } | |
| if queryParams.Action != "" { | |
| query = query.Where("action = ?", queryParams.Action) | |
| } | |
| if queryParams.Status != "" { | |
| query = query.Where("status = ?", queryParams.Status) | |
| } | |
| if queryParams.Platform != "" { | |
| query = query.Where("platform = ?", queryParams.Platform) | |
| } | |
| if queryParams.StartTimestamp != 0 { | |
| // 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析 | |
| query = query.Where("submit_time >= ?", queryParams.StartTimestamp) | |
| } | |
| if queryParams.EndTimestamp != 0 { | |
| query = query.Where("submit_time <= ?", queryParams.EndTimestamp) | |
| } | |
| // 获取数据 | |
| err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error | |
| if err != nil { | |
| return nil | |
| } | |
| return tasks | |
| } | |
| func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task { | |
| var tasks []*Task | |
| var err error | |
| // 初始化查询构建器 | |
| query := DB | |
| // 添加过滤条件 | |
| if queryParams.ChannelID != "" { | |
| query = query.Where("channel_id = ?", queryParams.ChannelID) | |
| } | |
| if queryParams.Platform != "" { | |
| query = query.Where("platform = ?", queryParams.Platform) | |
| } | |
| if queryParams.UserID != "" { | |
| query = query.Where("user_id = ?", queryParams.UserID) | |
| } | |
| if len(queryParams.UserIDs) != 0 { | |
| query = query.Where("user_id in (?)", queryParams.UserIDs) | |
| } | |
| if queryParams.TaskID != "" { | |
| query = query.Where("task_id = ?", queryParams.TaskID) | |
| } | |
| if queryParams.Action != "" { | |
| query = query.Where("action = ?", queryParams.Action) | |
| } | |
| if queryParams.Status != "" { | |
| query = query.Where("status = ?", queryParams.Status) | |
| } | |
| if queryParams.StartTimestamp != 0 { | |
| query = query.Where("submit_time >= ?", queryParams.StartTimestamp) | |
| } | |
| if queryParams.EndTimestamp != 0 { | |
| query = query.Where("submit_time <= ?", queryParams.EndTimestamp) | |
| } | |
| // 获取数据 | |
| err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error | |
| if err != nil { | |
| return nil | |
| } | |
| return tasks | |
| } | |
| func GetAllUnFinishSyncTasks(limit int) []*Task { | |
| var tasks []*Task | |
| var err error | |
| // get all tasks progress is not 100% | |
| err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error | |
| if err != nil { | |
| return nil | |
| } | |
| return tasks | |
| } | |
| func GetByOnlyTaskId(taskId string) (*Task, bool, error) { | |
| if taskId == "" { | |
| return nil, false, nil | |
| } | |
| var task *Task | |
| var err error | |
| err = DB.Where("task_id = ?", taskId).First(&task).Error | |
| exist, err := RecordExist(err) | |
| if err != nil { | |
| return nil, false, err | |
| } | |
| return task, exist, err | |
| } | |
| func GetByTaskId(userId int, taskId string) (*Task, bool, error) { | |
| if taskId == "" { | |
| return nil, false, nil | |
| } | |
| var task *Task | |
| var err error | |
| err = DB.Where("user_id = ? and task_id = ?", userId, taskId). | |
| First(&task).Error | |
| exist, err := RecordExist(err) | |
| if err != nil { | |
| return nil, false, err | |
| } | |
| return task, exist, err | |
| } | |
| func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) { | |
| if len(taskIds) == 0 { | |
| return nil, nil | |
| } | |
| var task []*Task | |
| var err error | |
| err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds). | |
| Find(&task).Error | |
| if err != nil { | |
| return nil, err | |
| } | |
| return task, nil | |
| } | |
| func TaskUpdateProgress(id int64, progress string) error { | |
| return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error | |
| } | |
| func (Task *Task) Insert() error { | |
| var err error | |
| err = DB.Create(Task).Error | |
| return err | |
| } | |
| func (Task *Task) Update() error { | |
| var err error | |
| err = DB.Save(Task).Error | |
| return err | |
| } | |
| func TaskBulkUpdate(TaskIds []string, params map[string]any) error { | |
| if len(TaskIds) == 0 { | |
| return nil | |
| } | |
| return DB.Model(&Task{}). | |
| Where("task_id in (?)", TaskIds). | |
| Updates(params).Error | |
| } | |
| func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error { | |
| if len(taskIDs) == 0 { | |
| return nil | |
| } | |
| return DB.Model(&Task{}). | |
| Where("id in (?)", taskIDs). | |
| Updates(params).Error | |
| } | |
| func TaskBulkUpdateByID(ids []int64, params map[string]any) error { | |
| if len(ids) == 0 { | |
| return nil | |
| } | |
| return DB.Model(&Task{}). | |
| Where("id in (?)", ids). | |
| Updates(params).Error | |
| } | |
| type TaskQuotaUsage struct { | |
| Mode string `json:"mode"` | |
| Count float64 `json:"count"` | |
| } | |
| func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) { | |
| query := DB.Model(Task{}) | |
| // 添加过滤条件 | |
| if queryParams.ChannelID != "" { | |
| query = query.Where("channel_id = ?", queryParams.ChannelID) | |
| } | |
| if queryParams.UserID != "" { | |
| query = query.Where("user_id = ?", queryParams.UserID) | |
| } | |
| if len(queryParams.UserIDs) != 0 { | |
| query = query.Where("user_id in (?)", queryParams.UserIDs) | |
| } | |
| if queryParams.TaskID != "" { | |
| query = query.Where("task_id = ?", queryParams.TaskID) | |
| } | |
| if queryParams.Action != "" { | |
| query = query.Where("action = ?", queryParams.Action) | |
| } | |
| if queryParams.Status != "" { | |
| query = query.Where("status = ?", queryParams.Status) | |
| } | |
| if queryParams.StartTimestamp != 0 { | |
| query = query.Where("submit_time >= ?", queryParams.StartTimestamp) | |
| } | |
| if queryParams.EndTimestamp != 0 { | |
| query = query.Where("submit_time <= ?", queryParams.EndTimestamp) | |
| } | |
| err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error | |
| return stat, err | |
| } | |
| // TaskCountAllTasks returns total tasks that match the given query params (admin usage) | |
| func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 { | |
| var total int64 | |
| query := DB.Model(&Task{}) | |
| if queryParams.ChannelID != "" { | |
| query = query.Where("channel_id = ?", queryParams.ChannelID) | |
| } | |
| if queryParams.Platform != "" { | |
| query = query.Where("platform = ?", queryParams.Platform) | |
| } | |
| if queryParams.UserID != "" { | |
| query = query.Where("user_id = ?", queryParams.UserID) | |
| } | |
| if len(queryParams.UserIDs) != 0 { | |
| query = query.Where("user_id in (?)", queryParams.UserIDs) | |
| } | |
| if queryParams.TaskID != "" { | |
| query = query.Where("task_id = ?", queryParams.TaskID) | |
| } | |
| if queryParams.Action != "" { | |
| query = query.Where("action = ?", queryParams.Action) | |
| } | |
| if queryParams.Status != "" { | |
| query = query.Where("status = ?", queryParams.Status) | |
| } | |
| if queryParams.StartTimestamp != 0 { | |
| query = query.Where("submit_time >= ?", queryParams.StartTimestamp) | |
| } | |
| if queryParams.EndTimestamp != 0 { | |
| query = query.Where("submit_time <= ?", queryParams.EndTimestamp) | |
| } | |
| _ = query.Count(&total).Error | |
| return total | |
| } | |
| // TaskCountAllUserTask returns total tasks for given user | |
| func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 { | |
| var total int64 | |
| query := DB.Model(&Task{}).Where("user_id = ?", userId) | |
| if queryParams.TaskID != "" { | |
| query = query.Where("task_id = ?", queryParams.TaskID) | |
| } | |
| if queryParams.Action != "" { | |
| query = query.Where("action = ?", queryParams.Action) | |
| } | |
| if queryParams.Status != "" { | |
| query = query.Where("status = ?", queryParams.Status) | |
| } | |
| if queryParams.Platform != "" { | |
| query = query.Where("platform = ?", queryParams.Platform) | |
| } | |
| if queryParams.StartTimestamp != 0 { | |
| query = query.Where("submit_time >= ?", queryParams.StartTimestamp) | |
| } | |
| if queryParams.EndTimestamp != 0 { | |
| query = query.Where("submit_time <= ?", queryParams.EndTimestamp) | |
| } | |
| _ = query.Count(&total).Error | |
| return total | |
| } | |