// Update api/chatgpt/api.go (commit 8df2678, committed by dvc890).
package chatgpt
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"strings"
http "github.com/bogdanfinn/fhttp"
"github.com/gin-gonic/gin"
"github.com/linweiyuan/go-chatgpt-api/api"
"github.com/linweiyuan/go-logger/logger"
)
// CreateConversation is the gin handler that forwards a conversation request
// to the upstream conversation endpoint and streams the reply to the client.
//
// It binds the JSON body into a CreateConversationRequest, fills in defaults
// for the first message, obtains an Arkose token when the model name starts
// with gpt4Model and none was supplied, and then hands over to
// sendConversationRequest / handleConversationResponse.
//
// On failure the request is aborted with a JSON error body:
//   - 400 when the body cannot be bound,
//   - 403 when an Arkose token is required but cannot be obtained.
func CreateConversation(c *gin.Context) {
	var request CreateConversationRequest
	if err := c.BindJSON(&request); err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(parseJsonErrorMessage))
		return
	}

	// Treat an empty conversation ID exactly like an absent one.
	if request.ConversationID != nil && *request.ConversationID == "" {
		request.ConversationID = nil
	}

	// Only the first message receives defaults; later messages are passed
	// through unchanged.
	if len(request.Messages) != 0 {
		first := &request.Messages[0]
		if first.Author.Role == "" {
			first.Author.Role = defaultRole
		}
		if first.Metadata == nil {
			first.Metadata = map[string]string{}
		}
	}

	if strings.HasPrefix(request.Model, gpt4Model) && request.ArkoseToken == "" {
		arkoseToken, err := api.GetArkoseToken()
		if err != nil {
			c.AbortWithStatusJSON(http.StatusForbidden, api.ReturnMessage(err.Error()))
			return
		}
		// An empty token with a nil error must be handled separately:
		// calling err.Error() here would dereference a nil error and panic.
		if arkoseToken == "" {
			c.AbortWithStatusJSON(http.StatusForbidden, api.ReturnMessage("failed to get arkose token"))
			return
		}
		request.ArkoseToken = arkoseToken
	}

	resp, done := sendConversationRequest(c, request)
	if done {
		return
	}
	handleConversationResponse(c, resp, request)
}
// sendConversationRequest POSTs request as JSON to the upstream
// /backend-api/conversation endpoint and asks for an event-stream reply.
//
// It returns (resp, false) when the upstream answered 200 OK; the caller then
// owns resp.Body and must close it. In every other case it has already written
// an error response on c and returns (nil, true):
//   - 500 when the request cannot be built or sent,
//   - the upstream status with the upstream JSON body on 401,
//   - 403 when the requested model is not in the models list returned for
//     the caller's access token,
//   - otherwise the upstream status with the upstream JSON body.
func sendConversationRequest(c *gin.Context, request CreateConversationRequest) (*http.Response, bool) {
	jsonBytes, err := json.Marshal(request)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
		return nil, true
	}

	req, err := http.NewRequest(http.MethodPost, api.ChatGPTApiUrlPrefix+"/backend-api/conversation", bytes.NewBuffer(jsonBytes))
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
		return nil, true
	}
	req.Header.Set("User-Agent", api.UserAgent)
	req.Header.Set(api.AuthorizationHeader, api.GetAccessToken(c))
	req.Header.Set("Accept", "text/event-stream")
	if api.PUID != "" {
		req.Header.Set("Cookie", "_puid="+api.PUID)
	}

	resp, err := api.Client.Do(req)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
		return nil, true
	}
	if resp.StatusCode == http.StatusOK {
		return resp, false
	}

	// Every path below reports an error to the client, so the upstream body
	// is never handed to the caller and must be closed here.
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusUnauthorized {
		logger.Error(fmt.Sprintf(api.AccountDeactivatedErrorMessage, c.GetString(api.EmailKey)))
		responseMap := make(map[string]interface{})
		// Best effort: a body that is not JSON leaves the map empty.
		_ = json.NewDecoder(resp.Body).Decode(&responseMap)
		c.AbortWithStatusJSON(resp.StatusCode, responseMap)
		return nil, true
	}

	// Check whether the failure is explained by the token lacking access to
	// the requested model before relaying the raw upstream error.
	modelsReq, err := http.NewRequest(http.MethodGet, api.ChatGPTApiUrlPrefix+"/backend-api/models?history_and_training_disabled=false", nil)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
		return nil, true
	}
	modelsReq.Header.Set("User-Agent", api.UserAgent)
	modelsReq.Header.Set(api.AuthorizationHeader, api.GetAccessToken(c))
	modelsResp, err := api.Client.Do(modelsReq)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
		return nil, true
	}
	defer modelsResp.Body.Close()

	var getModelsResponse GetModelsResponse
	// Best effort: if the models list cannot be decoded it stays empty and
	// the model is reported as unavailable below.
	_ = json.NewDecoder(modelsResp.Body).Decode(&getModelsResponse)

	modelAvailable := false
	for _, model := range getModelsResponse.Models {
		if model.Slug == request.Model {
			modelAvailable = true
			break
		}
	}
	if !modelAvailable {
		c.AbortWithStatusJSON(http.StatusForbidden, api.ReturnMessage(noModelPermissionErrorMessage))
		return nil, true
	}

	data, _ := io.ReadAll(resp.Body)
	logger.Warn(string(data))
	responseMap := make(map[string]interface{})
	// Decode from the bytes already read: resp.Body is exhausted after
	// io.ReadAll, so decoding from it again would always yield an empty map
	// and hide the upstream error from the client.
	_ = json.Unmarshal(data, &responseMap)
	c.AbortWithStatusJSON(resp.StatusCode, responseMap)
	return nil, true
}
// handleConversationResponse relays the upstream event stream in resp to the
// client and, when the reply was cut off because of the token limit, issues a
// follow-up "continue" request and relays that stream as well.
//
// Auto-continue is always enabled here: request.AutoContinue is forced to
// true regardless of what the caller sent. resp.Body is closed before this
// function returns or recurses.
func handleConversationResponse(c *gin.Context, resp *http.Response, request CreateConversationRequest) {
	c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	request.AutoContinue = true

	isMaxTokens, continueParentMessageID, continueConversationID := streamConversationEvents(c, resp, request.AutoContinue)
	if !isMaxTokens || !request.AutoContinue {
		return
	}

	continueConversationRequest := CreateConversationRequest{
		ArkoseToken:                request.ArkoseToken,
		HistoryAndTrainingDisabled: request.HistoryAndTrainingDisabled,
		Model:                      request.Model,
		TimezoneOffsetMin:          request.TimezoneOffsetMin,
		Action:                     actionContinue,
		ParentMessageID:            continueParentMessageID,
		ConversationID:             &continueConversationID,
	}
	nextResp, done := sendConversationRequest(c, continueConversationRequest)
	if done {
		return
	}
	handleConversationResponse(c, nextResp, continueConversationRequest)
}

// streamConversationEvents copies the lines of one upstream event stream to
// the client, closing resp.Body when done. It reports whether a message
// finished with the max-tokens finish type, together with that message's ID
// and conversation ID, so the caller can request a continuation.
//
// Reading stops when the client's request context is cancelled or the
// upstream stream ends or fails.
func streamConversationEvents(c *gin.Context, resp *http.Response, autoContinue bool) (isMaxTokens bool, parentMessageID, conversationID string) {
	const dataPrefix = "data: "

	// Close as soon as this stream is drained, so that a chain of
	// continuations does not keep every earlier upstream body open.
	defer resp.Body.Close()

	reader := bufio.NewReader(resp.Body)
	for {
		if c.Request.Context().Err() != nil {
			break
		}
		line, err := reader.ReadString('\n')
		if err != nil {
			break
		}

		line = strings.TrimSpace(line)
		// Lines that are never forwarded: blanks, "event" lines, data lines
		// whose payload starts with "20" (presumably timestamps - TODO
		// confirm) and data lines whose payload starts with a
		// conversation_id key.
		if line == "" ||
			strings.HasPrefix(line, "event") ||
			strings.HasPrefix(line, "data: 20") ||
			strings.HasPrefix(line, `data: {"conversation_id"`) {
			continue
		}

		// Strip the prefix only when it is present. Slicing line[6:]
		// unconditionally panics on any line shorter than six bytes, such as
		// a bare "data:" or an SSE comment line.
		var responseJson string
		if strings.HasPrefix(line, dataPrefix) {
			responseJson = line[len(dataPrefix):]
		}

		// Hold back the terminating [DONE] while a continuation is pending,
		// so the client sees one uninterrupted stream.
		if strings.HasPrefix(responseJson, "[DONE]") && isMaxTokens && autoContinue {
			continue
		}

		// No need to unmarshal every line: only payloads that mention the
		// max-tokens marker are inspected further.
		if strings.TrimSpace(responseJson) != "" && strings.Contains(responseJson, responseTypeMaxTokens) {
			var createConversationResponse CreateConversationResponse
			if err := json.Unmarshal([]byte(responseJson), &createConversationResponse); err == nil {
				message := createConversationResponse.Message
				if message.Metadata.FinishDetails.Type == responseTypeMaxTokens && message.Status == responseStatusFinishedSuccessfully {
					isMaxTokens = true
					parentMessageID = message.ID
					conversationID = createConversationResponse.ConversationID
				}
			}
		}

		c.Writer.Write([]byte(line + "\n\n"))
		c.Writer.Flush()
	}
	return isMaxTokens, parentMessageID, conversationID
}