Spaces:
Runtime error
Runtime error
package wsstostream | |
import ( | |
"WarpGPT/pkg/common" | |
"WarpGPT/pkg/env" | |
"WarpGPT/pkg/logger" | |
"WarpGPT/pkg/tools" | |
"bytes" | |
"encoding/base64" | |
"encoding/json" | |
"errors" | |
http "github.com/bogdanfinn/fhttp" | |
"github.com/gorilla/websocket" | |
"golang.org/x/net/proxy" | |
"io" | |
shttp "net/http" | |
"net/url" | |
"time" | |
) | |
type RegisterWebsocket struct { | |
ExpiresAt time.Time `json:"expires_at"` | |
WssUrl string `json:"wss_url"` | |
} | |
type WsResponse struct { | |
SequenceId int `json:"sequenceId"` | |
Type string `json:"type"` | |
From string `json:"from"` | |
DataType string `json:"dataType"` | |
Data struct { | |
Type string `json:"type"` | |
Body string `json:"body"` | |
MoreBody bool `json:"more_body"` | |
ResponseId string `json:"response_id"` | |
ConversationId string `json:"conversation_id"` | |
MessageId string `json:"message_id"` | |
} `json:"data"` | |
} | |
type Reconnect struct { | |
Type string `json:"type"` | |
Event string `json:"event"` | |
UserId string `json:"userId"` | |
ConnectionId string `json:"connectionId"` | |
ReconnectionToken string `json:"reconnectionToken"` | |
} | |
type WssToStream struct { | |
ConversationId string | |
ResponseId string | |
AccessToken string | |
Server *websocket.Conn | |
WS *RegisterWebsocket | |
Reconnect | |
} | |
func NewWssToStream(accessToken string) *WssToStream { | |
return &WssToStream{AccessToken: accessToken} | |
} | |
func GetRegisterWebsocket(accessToken string) (*RegisterWebsocket, error) { | |
logger.Log.Debug("GetRegisterWebsocket") | |
WS, err := common.RequestOpenAI[RegisterWebsocket]("/backend-api/register-websocket", nil, accessToken, http.MethodPost) | |
if err != nil { | |
logger.Log.Error("Error decoding response:", err) | |
return nil, err | |
} | |
if WS != nil && WS.WssUrl != "" { | |
logger.Log.Debug("GetRegisterWebsocket Success WssUrl:", WS.WssUrl) | |
return WS, nil | |
} else { | |
logger.Log.Debug("accessToken:", accessToken) | |
return nil, errors.New("check your access_key") | |
} | |
} | |
func (s *WssToStream) InitConnect() error { | |
logger.Log.Debug("Try Connect To WS") | |
var dialer websocket.Dialer | |
// 当 env.E.Proxy 不为空字符串时,才配置代理 | |
if env.E.Proxy != "" { | |
proxyAddr, err := url.Parse(env.E.Proxy) | |
if err != nil { | |
logger.Log.Error("Error parsing proxy URL:", err) | |
return err | |
} | |
switch proxyAddr.Scheme { | |
case "http", "https": | |
dialer.Proxy = shttp.ProxyURL(proxyAddr) | |
case "socks5": | |
socksDialer, err := proxy.FromURL(proxyAddr, proxy.Direct) | |
if err != nil { | |
logger.Log.Error("Error creating SOCKS proxy dialer:", err) | |
return err | |
} | |
dialer.NetDial = socksDialer.Dial | |
default: | |
logger.Log.Error("Unsupported proxy scheme:", proxyAddr.Scheme) | |
return errors.New("unsupported proxy scheme") | |
} | |
} | |
headers := http.Header{} | |
headers.Set("Origin", "https://"+env.E.OpenaiHost) | |
headers.Set("Sec-WebSocket-Protocol", "json.reliable.webpubsub.azure.v1") | |
headers.Set("User-Agent", env.E.UserAgent) | |
item, exists := tools.AllCache.CacheGet(s.AccessToken) | |
if !exists || item.ExpiresAt.Before(time.Now()) { | |
registerWebsocket, err := GetRegisterWebsocket(s.AccessToken) | |
if err != nil { | |
return err | |
} | |
tools.AllCache.CacheSet(s.AccessToken, tools.CacheItem{Data: registerWebsocket}, 55*time.Minute) | |
s.WS = registerWebsocket | |
} else { | |
s.WS = item.Data.(*RegisterWebsocket) | |
} | |
c, _, err := dialer.Dial(s.WS.WssUrl, shttp.Header(headers)) | |
if err != nil { | |
logger.Log.Error("Dial error:", err) | |
return err | |
} | |
logger.Log.Debug("WS Connect Success") | |
s.Server = c | |
_, msg, err := s.Server.ReadMessage() | |
if err != nil { | |
return err | |
} | |
logger.Log.Debug("Init Read Message:", string(msg)) | |
return nil | |
} | |
type NopCloser struct { | |
*bytes.Reader | |
} | |
func (NopCloser) Close() error { | |
return nil | |
} | |
func NewNopCloser(data []byte) io.ReadCloser { | |
return NopCloser{Reader: bytes.NewReader(data)} | |
} | |
func (s *WssToStream) ReadMessage() (io.ReadCloser, error) { | |
logger.Log.Debug("Read Messages") | |
_, msg, err := s.Server.ReadMessage() | |
if err != nil { | |
logger.Log.Error("read message error:", err) | |
return nil, err | |
} | |
var response WsResponse | |
if err = json.Unmarshal(msg, &response); err != nil { | |
logger.Log.Error("unmarshal message error:", err) | |
return nil, err | |
} | |
if response.Data.ResponseId == s.ResponseId && response.Data.ConversationId == s.ConversationId { | |
data, err := base64.StdEncoding.DecodeString(response.Data.Body) | |
if err != nil { | |
logger.Log.Error("decode base64 message error:", err) | |
return nil, err | |
} | |
if response.Data.Body == "ZGF0YTogW0RPTkVdCgo=" { | |
s.Server.Close() | |
return NewNopCloser(data), io.EOF | |
} | |
return NewNopCloser(data), nil | |
} | |
return nil, errors.New("response id or conversation id does not match") | |
} | |
func (s *WssToStream) Read(p []byte) (n int, err error) { | |
logger.Log.Debug("Read") | |
_, message, err := s.Server.ReadMessage() | |
if err != nil { | |
return 0, err | |
} | |
var response WsResponse | |
if err = json.Unmarshal(message, &response); err != nil { | |
logger.Log.Error("unmarshal message error:", err) | |
return 0, err | |
} | |
if response.Data.ResponseId == s.ResponseId && response.Data.ConversationId == s.ConversationId { | |
data, err := base64.StdEncoding.DecodeString(response.Data.Body) | |
if err != nil { | |
logger.Log.Error("decode base64 message error:", err) | |
return 0, err | |
} | |
copyLen := copy(p, data) | |
if copyLen < len(data) { | |
return copyLen, io.ErrShortBuffer | |
} | |
if response.Data.Body == "ZGF0YTogW0RPTkVdCgo=" { | |
s.Server.Close() | |
return copyLen, io.EOF | |
} | |
return copyLen, nil | |
} | |
return 0, errors.New("response id or conversation id do not match") | |
} | |
func (s *WssToStream) Close() error { | |
return s.Server.Close() | |
} | |