Spaces:
Paused
Paused
package chatgpt | |
import ( | |
"bufio" | |
"bytes" | |
"encoding/json" | |
"fmt" | |
"io" | |
"strings" | |
http "github.com/bogdanfinn/fhttp" | |
"github.com/gin-gonic/gin" | |
"github.com/linweiyuan/go-chatgpt-api/api" | |
"github.com/linweiyuan/go-logger/logger" | |
) | |
func CreateConversation(c *gin.Context) { | |
var request CreateConversationRequest | |
if err := c.BindJSON(&request); err != nil { | |
c.AbortWithStatusJSON(http.StatusBadRequest, api.ReturnMessage(parseJsonErrorMessage)) | |
return | |
} | |
if request.ConversationID == nil || *request.ConversationID == "" { | |
request.ConversationID = nil | |
} | |
if len(request.Messages) != 0 { | |
message := request.Messages[0] | |
if message.Author.Role == "" { | |
message.Author.Role = defaultRole | |
} | |
if message.Metadata == nil { | |
message.Metadata = map[string]string{} | |
} | |
request.Messages[0] = message | |
} | |
if strings.HasPrefix(request.Model, gpt4Model) && request.ArkoseToken == "" { | |
arkoseToken, err := api.GetArkoseToken() | |
if err != nil || arkoseToken == "" { | |
c.AbortWithStatusJSON(http.StatusForbidden, api.ReturnMessage(err.Error())) | |
return | |
} | |
request.ArkoseToken = arkoseToken | |
} | |
resp, done := sendConversationRequest(c, request) | |
if done { | |
return | |
} | |
handleConversationResponse(c, resp, request) | |
} | |
func sendConversationRequest(c *gin.Context, request CreateConversationRequest) (*http.Response, bool) { | |
jsonBytes, _ := json.Marshal(request) | |
req, _ := http.NewRequest(http.MethodPost, api.ChatGPTApiUrlPrefix+"/backend-api/conversation", bytes.NewBuffer(jsonBytes)) | |
req.Header.Set("User-Agent", api.UserAgent) | |
req.Header.Set(api.AuthorizationHeader, api.GetAccessToken(c)) | |
req.Header.Set("Accept", "text/event-stream") | |
if api.PUID != "" { | |
req.Header.Set("Cookie", "_puid="+api.PUID) | |
} | |
resp, err := api.Client.Do(req) | |
if err != nil { | |
c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error())) | |
return nil, true | |
} | |
if resp.StatusCode != http.StatusOK { | |
defer resp.Body.Close() | |
if resp.StatusCode == http.StatusUnauthorized { | |
logger.Error(fmt.Sprintf(api.AccountDeactivatedErrorMessage, c.GetString(api.EmailKey))) | |
responseMap := make(map[string]interface{}) | |
json.NewDecoder(resp.Body).Decode(&responseMap) | |
c.AbortWithStatusJSON(resp.StatusCode, responseMap) | |
return nil, true | |
} | |
req, _ := http.NewRequest(http.MethodGet, api.ChatGPTApiUrlPrefix+"/backend-api/models?history_and_training_disabled=false", nil) | |
req.Header.Set("User-Agent", api.UserAgent) | |
req.Header.Set(api.AuthorizationHeader, api.GetAccessToken(c)) | |
response, err := api.Client.Do(req) | |
if err != nil { | |
c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error())) | |
return nil, true | |
} | |
defer response.Body.Close() | |
modelAvailable := false | |
var getModelsResponse GetModelsResponse | |
json.NewDecoder(response.Body).Decode(&getModelsResponse) | |
for _, model := range getModelsResponse.Models { | |
if model.Slug == request.Model { | |
modelAvailable = true | |
break | |
} | |
} | |
if !modelAvailable { | |
c.AbortWithStatusJSON(http.StatusForbidden, api.ReturnMessage(noModelPermissionErrorMessage)) | |
return nil, true | |
} | |
data, _ := io.ReadAll(resp.Body) | |
logger.Warn(string(data)) | |
responseMap := make(map[string]interface{}) | |
json.NewDecoder(resp.Body).Decode(&responseMap) | |
c.AbortWithStatusJSON(resp.StatusCode, responseMap) | |
return nil, true | |
} | |
return resp, false | |
} | |
func handleConversationResponse(c *gin.Context, resp *http.Response, request CreateConversationRequest) { | |
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8") | |
isMaxTokens := false | |
continueParentMessageID := "" | |
continueConversationID := "" | |
request.AutoContinue = true | |
defer resp.Body.Close() | |
reader := bufio.NewReader(resp.Body) | |
for { | |
if c.Request.Context().Err() != nil { | |
break | |
} | |
line, err := reader.ReadString('\n') | |
if err != nil { | |
break | |
} | |
line = strings.TrimSpace(line) | |
if strings.HasPrefix(line, "event") || | |
strings.HasPrefix(line, "data: 20") || | |
strings.HasPrefix(line, `data: {"conversation_id"`) || | |
line == "" { | |
continue | |
} | |
responseJson := line[6:] | |
if strings.HasPrefix(responseJson, "[DONE]") && isMaxTokens && request.AutoContinue { | |
continue | |
} | |
// no need to unmarshal every time, but if response content has this "max_tokens", need to further check | |
if strings.TrimSpace(responseJson) != "" && strings.Contains(responseJson, responseTypeMaxTokens) { | |
var createConversationResponse CreateConversationResponse | |
json.Unmarshal([]byte(responseJson), &createConversationResponse) | |
message := createConversationResponse.Message | |
if message.Metadata.FinishDetails.Type == responseTypeMaxTokens && createConversationResponse.Message.Status == responseStatusFinishedSuccessfully { | |
isMaxTokens = true | |
continueParentMessageID = message.ID | |
continueConversationID = createConversationResponse.ConversationID | |
} | |
} | |
c.Writer.Write([]byte(line + "\n\n")) | |
c.Writer.Flush() | |
} | |
if isMaxTokens && request.AutoContinue { | |
continueConversationRequest := CreateConversationRequest{ | |
ArkoseToken: request.ArkoseToken, | |
HistoryAndTrainingDisabled: request.HistoryAndTrainingDisabled, | |
Model: request.Model, | |
TimezoneOffsetMin: request.TimezoneOffsetMin, | |
Action: actionContinue, | |
ParentMessageID: continueParentMessageID, | |
ConversationID: &continueConversationID, | |
} | |
resp, done := sendConversationRequest(c, continueConversationRequest) | |
if done { | |
return | |
} | |
handleConversationResponse(c, resp, continueConversationRequest) | |
} | |
} | |