| package admin |
|
|
import (
	"errors"
	"io"
	"strconv"
	"strings"

	"github.com/Wei-Shaw/sub2api/internal/handler/dto"
	"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
	"github.com/Wei-Shaw/sub2api/internal/pkg/response"
	"github.com/Wei-Shaw/sub2api/internal/service"

	"github.com/gin-gonic/gin"
)
|
|
| |
// OpenAIOAuthHandler serves the admin OAuth endpoints for OpenAI-family
// accounts. The target platform (OpenAI or Sora) is derived per request from
// the matched route via oauthPlatformFromPath.
type OpenAIOAuthHandler struct {
	// openaiOAuthService generates auth URLs, exchanges codes/session tokens,
	// refreshes tokens and builds account credential maps.
	openaiOAuthService *service.OpenAIOAuthService
	// adminService is used for proxy lookups and account get/create/update.
	adminService service.AdminService
}
|
|
| func oauthPlatformFromPath(c *gin.Context) string { |
| if strings.Contains(c.FullPath(), "/admin/sora/") { |
| return service.PlatformSora |
| } |
| return service.PlatformOpenAI |
| } |
|
|
| |
| func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler { |
| return &OpenAIOAuthHandler{ |
| openaiOAuthService: openaiOAuthService, |
| adminService: adminService, |
| } |
| } |
|
|
| |
// OpenAIGenerateAuthURLRequest is the JSON body accepted by GenerateAuthURL.
// Neither field carries a binding constraint, so both may be omitted.
type OpenAIGenerateAuthURLRequest struct {
	// ProxyID optionally selects a stored proxy; passed through to the OAuth service.
	ProxyID *int64 `json:"proxy_id"`
	// RedirectURI is passed through to the OAuth service as-is; presumably an
	// empty value means "use the service default" — confirm in the service.
	RedirectURI string `json:"redirect_uri"`
}
|
|
| |
| |
| func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { |
| var req OpenAIGenerateAuthURLRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| |
| req = OpenAIGenerateAuthURLRequest{} |
| } |
|
|
| result, err := h.openaiOAuthService.GenerateAuthURL( |
| c.Request.Context(), |
| req.ProxyID, |
| req.RedirectURI, |
| oauthPlatformFromPath(c), |
| ) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, result) |
| } |
|
|
| |
// OpenAIExchangeCodeRequest is the JSON body accepted by ExchangeCode.
// session_id, code and state are required (enforced by gin binding);
// redirect_uri and proxy_id are optional. All five are forwarded unchanged to
// the OAuth service's ExchangeCode.
type OpenAIExchangeCodeRequest struct {
	SessionID   string `json:"session_id" binding:"required"`
	Code        string `json:"code" binding:"required"`
	State       string `json:"state" binding:"required"`
	RedirectURI string `json:"redirect_uri"`
	ProxyID     *int64 `json:"proxy_id"`
}
|
|
| |
| |
| func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { |
| var req OpenAIExchangeCodeRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
|
|
| tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ |
| SessionID: req.SessionID, |
| Code: req.Code, |
| State: req.State, |
| RedirectURI: req.RedirectURI, |
| ProxyID: req.ProxyID, |
| }) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, tokenInfo) |
| } |
|
|
| |
// OpenAIRefreshTokenRequest is the JSON body accepted by RefreshToken.
type OpenAIRefreshTokenRequest struct {
	// RefreshToken is the token to refresh; surrounding whitespace is trimmed.
	RefreshToken string `json:"refresh_token"`
	// RT is a short alias, consulted only when refresh_token is blank.
	RT string `json:"rt"`
	// ClientID overrides the OAuth client ID. When blank, the handler resolves
	// one via openai.OAuthClientConfigByPlatform for the route's platform.
	ClientID string `json:"client_id"`
	// ProxyID optionally selects a stored proxy. If the lookup fails, the
	// refresh proceeds with an empty proxy URL.
	ProxyID *int64 `json:"proxy_id"`
}
|
|
| |
| |
| |
| func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { |
| var req OpenAIRefreshTokenRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
| refreshToken := strings.TrimSpace(req.RefreshToken) |
| if refreshToken == "" { |
| refreshToken = strings.TrimSpace(req.RT) |
| } |
| if refreshToken == "" { |
| response.BadRequest(c, "refresh_token is required") |
| return |
| } |
|
|
| var proxyURL string |
| if req.ProxyID != nil { |
| proxy, err := h.adminService.GetProxy(c.Request.Context(), *req.ProxyID) |
| if err == nil && proxy != nil { |
| proxyURL = proxy.URL() |
| } |
| } |
|
|
| |
| clientID := strings.TrimSpace(req.ClientID) |
| if clientID == "" { |
| platform := oauthPlatformFromPath(c) |
| clientID, _ = openai.OAuthClientConfigByPlatform(platform) |
| } |
|
|
| tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, tokenInfo) |
| } |
|
|
| |
| |
| func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) { |
| var req struct { |
| SessionToken string `json:"session_token"` |
| ST string `json:"st"` |
| ProxyID *int64 `json:"proxy_id"` |
| } |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
|
|
| sessionToken := strings.TrimSpace(req.SessionToken) |
| if sessionToken == "" { |
| sessionToken = strings.TrimSpace(req.ST) |
| } |
| if sessionToken == "" { |
| response.BadRequest(c, "session_token is required") |
| return |
| } |
|
|
| tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
| response.Success(c, tokenInfo) |
| } |
|
|
| |
| |
| |
| func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { |
| accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) |
| if err != nil { |
| response.BadRequest(c, "Invalid account ID") |
| return |
| } |
|
|
| |
| account, err := h.adminService.GetAccount(c.Request.Context(), accountID) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| platform := oauthPlatformFromPath(c) |
| if account.Platform != platform { |
| response.BadRequest(c, "Account platform does not match OAuth endpoint") |
| return |
| } |
|
|
| |
| if !account.IsOAuth() { |
| response.BadRequest(c, "Cannot refresh non-OAuth account credentials") |
| return |
| } |
|
|
| |
| tokenInfo, err := h.openaiOAuthService.RefreshAccountToken(c.Request.Context(), account) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| |
| newCredentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo) |
|
|
| |
| for k, v := range account.Credentials { |
| if _, exists := newCredentials[k]; !exists { |
| newCredentials[k] = v |
| } |
| } |
|
|
| updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ |
| Credentials: newCredentials, |
| }) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, dto.AccountFromService(updatedAccount)) |
| } |
|
|
| |
| |
| |
| func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { |
| var req struct { |
| SessionID string `json:"session_id" binding:"required"` |
| Code string `json:"code" binding:"required"` |
| State string `json:"state" binding:"required"` |
| RedirectURI string `json:"redirect_uri"` |
| ProxyID *int64 `json:"proxy_id"` |
| Name string `json:"name"` |
| Concurrency int `json:"concurrency"` |
| Priority int `json:"priority"` |
| GroupIDs []int64 `json:"group_ids"` |
| } |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
|
|
| |
| tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ |
| SessionID: req.SessionID, |
| Code: req.Code, |
| State: req.State, |
| RedirectURI: req.RedirectURI, |
| ProxyID: req.ProxyID, |
| }) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| |
| credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo) |
|
|
| platform := oauthPlatformFromPath(c) |
|
|
| |
| name := req.Name |
| if name == "" && tokenInfo.Email != "" { |
| name = tokenInfo.Email |
| } |
| if name == "" { |
| if platform == service.PlatformSora { |
| name = "Sora OAuth Account" |
| } else { |
| name = "OpenAI OAuth Account" |
| } |
| } |
|
|
| |
| account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ |
| Name: name, |
| Platform: platform, |
| Type: "oauth", |
| Credentials: credentials, |
| Extra: nil, |
| ProxyID: req.ProxyID, |
| Concurrency: req.Concurrency, |
| Priority: req.Priority, |
| GroupIDs: req.GroupIDs, |
| }) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, dto.AccountFromService(account)) |
| } |
|
|