Spaces:
Sleeping
Sleeping
| import fetch from 'node-fetch'; | |
| import express from 'express'; | |
| import cors from 'cors'; | |
| import dotenv from 'dotenv'; | |
dotenv.config();

// Credential groups loaded by initializeService() from AUTH_TOKENS_<n>_* environment variables.
const Tokens = [];
// TokenManager instance; assigned by initializeService().
let tokenManager;
// Index of the token group the next request uses; advanced round-robin by TokenManager.updateTokens.
let currentIndex = 0;

const CONFIG = {
  API: {
    BASE_URL: "https://partyrock.aws",
    // Key clients must send as "Authorization: Bearer <key>". Set API_KEY in the environment:
    // the fallback below is hard-coded in the source, so leaving it lets anyone use this proxy.
    API_KEY: process.env.API_KEY || "sk-123456",
  },
  RETRY: {
    // Total upstream attempts per request (1 = no retry).
    MAX_ATTEMPTS: 1,
    // Back-off unit in ms; after the k-th failed attempt the handler waits DELAY_BASE * k.
    DELAY_BASE: 1000,
  },
  SERVER: {
    PORT: process.env.PORT || 3000,
    // Maximum request body size accepted by the JSON and urlencoded parsers.
    BODY_LIMIT: '5mb',
  },
  // Client-facing model id -> PartyRock model name; the keys are also what /hf/v1/models lists.
  MODELS: {
    'claude-3-5-haiku-20241022': 'bedrock-anthropic.claude-3-5-haiku',
    'claude-3-5-sonnet-20241022': 'bedrock-anthropic.claude-3-5-sonnet-v2-0',
    'nova-lite-v1-0': 'bedrock-amazon.nova-lite-v1-0',
    'nova-pro-v1-0': 'bedrock-amazon.nova-pro-v1-0',
    // "7b" is a misnomer kept for existing clients; the upstream model is the 8B instruct model.
    'llama3-1-7b': 'bedrock-meta.llama3-1-8b-instruct-v1',
    'llama3-1-8b': 'bedrock-meta.llama3-1-8b-instruct-v1',
    'llama3-1-70b': 'bedrock-meta.llama3-1-70b-instruct-v1',
    'mistral-small': 'bedrock-mistral.mistral-small-2402-v1-0',
    'mistral-large': 'bedrock-mistral.mistral-large-2407-v1-0',
  },
  // Browser-like headers sent to PartyRock. The CSRF token, Cookie and referer are filled in
  // per request by TokenManager.updateCacheTokens, so this object is intentionally mutable.
  DEFAULT_HEADERS: {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
    "Connection": "keep-alive",
    "Accept": "text/event-stream",
    // "zstd" is not advertised: node-fetch only decompresses gzip, deflate and br, so a
    // zstd-encoded reply would reach the SSE parser as raw compressed bytes.
    "Accept-Encoding": "gzip, deflate, br",
    "Content-Type": "application/json",
    "anti-csrftoken-a2z": "",
    "sec-ch-ua-mobile": "?0",
    "origin": "https://partyrock.aws",
    "sec-fetch-site": "same-origin",
    "sec-fetch-mode": "cors",
    "sec-fetch-dest": "empty",
    "sec-ch-ua-platform": "\"Windows\"",
    "sec-ch-ua": "\"Google Chrome\";v=\"131\", \"Chromium\";v=\"131\", \"Not_A Brand\";v=\"24\"",
    "referer": "",
    "Cookie": "",
    "accept-language": "zh-CN,zh;q=0.9",
    "priority": "u=1, i",
  },
};

if (!process.env.API_KEY) {
  console.warn('未设置 API_KEY 环境变量,正在使用源码中的默认密钥,请尽快修改');
}
/**
 * Applies stored PartyRock credentials to the shared request headers and records credentials
 * refreshed by upstream responses. Token groups live in the module-level `Tokens` array and
 * are used round-robin through `currentIndex`.
 */
class TokenManager {
  // [token-group field, matcher] for cookies copied back from a response's Set-Cookie header.
  // Cookie values cannot contain ';', ',' or whitespace (RFC 6265 cookie-octet), so those
  // characters end a value even when several Set-Cookie headers are joined with ", ".
  static #ROTATED_COOKIES = [
    ['idToken', /(?:^|[;,\s])idToken=([^;,\s]+)/],
    ['pr_refresh_token', /(?:^|[;,\s])pr_refresh_token=([^;,\s]+)/],
    ['aws_waf_token', /(?:^|[;,\s])aws-waf-token=([^;,\s]+)/],
  ];

  /**
   * Copies the CSRF token, cookies and referer of one token group into CONFIG.DEFAULT_HEADERS.
   * @param {number} [index=currentIndex] - Token group to apply.
   * @throws {Error} When no token group exists at `index` (e.g. none were configured).
   */
  async updateCacheTokens(index = currentIndex) {
    const token = Tokens[index];
    if (!token) {
      throw new Error(`没有可用的 token 配置 (index: ${index})`);
    }
    const headers = CONFIG.DEFAULT_HEADERS;
    headers["anti-csrftoken-a2z"] = token.anti_csrftoken_a2z;
    headers.Cookie = `idToken=${token.idToken}; pr_refresh_token=${token.pr_refresh_token};aws-waf-token=${token.aws_waf_token}`;
    headers.referer = token.refreshUrl;
  }

  /**
   * Stores credentials returned by an upstream response, then advances the round-robin cursor.
   * @param {{ headers: { get(name: string): (string|null|undefined) } }} response - Upstream fetch response.
   * @param {number} [index=currentIndex] - Token group the request was sent with. Callers should
   *   pass the index captured before the request: other requests may move `currentIndex` while
   *   this one is in flight, and the refreshed values belong to the group actually used.
   */
  async updateTokens(response, index = currentIndex) {
    const token = Tokens[index];
    if (token) {
      const newCsrfToken = response.headers.get('anti-csrftoken-a2z');
      if (newCsrfToken) {
        token.anti_csrftoken_a2z = newCsrfToken;
      }
      const setCookie = response.headers.get('set-cookie');
      if (setCookie) {
        for (const [field, pattern] of TokenManager.#ROTATED_COOKIES) {
          const value = setCookie.match(pattern)?.[1];
          if (value) {
            token[field] = value;
          }
        }
      }
    }
    // Guard against `% 0`, which would turn the cursor into NaN when no tokens are loaded.
    if (Tokens.length > 0) {
      currentIndex = (currentIndex + 1) % Tokens.length;
    }
  }
}
/** Small stateless helpers shared by the service. */
class Utils {
  /**
   * Parses a Cookie-header style string ("a=1; b=2") into a name -> value object.
   * Only the first "=" separates name from value, so values may themselves contain "=".
   * Empty segments and segments without a usable name ("flag", "=x") are ignored.
   * Kept async because existing callers await it.
   * @param {string|undefined|null} cookieString
   * @returns {Promise<Record<string, string>>} Empty object when the input is not a string.
   */
  static async extractTokens(cookieString) {
    const tokens = {};
    if (typeof cookieString !== 'string') {
      return tokens;
    }
    for (const segment of cookieString.split(';')) {
      const pair = segment.trim();
      const splitIndex = pair.indexOf('=');
      // -1 means there is no "=", 0 means the name is empty; neither yields a usable cookie.
      if (splitIndex <= 0) {
        continue;
      }
      tokens[pair.slice(0, splitIndex).trim()] = pair.slice(splitIndex + 1).trim();
    }
    return tokens;
  }

  /**
   * Returns a random element of `arr` (undefined for an empty array).
   * @param {Array} arr
   */
  static getRandomElement(arr) {
    return arr[Math.floor(Math.random() * arr.length)];
  }

  /**
   * Builds a version-4 style UUID string from Math.random (not cryptographically secure).
   * @returns {string}
   */
  static uuidv4() {
    return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, (c) => {
      const r = (Math.random() * 16) | 0;
      // The "y" position must be one of 8, 9, a, b (RFC 4122 variant bits).
      const v = c === 'x' ? r : (r & 0x3) | 0x8;
      return v.toString(16);
    });
  }

  /**
   * Returns `length` random uppercase hexadecimal characters.
   * @param {number} length
   * @returns {string}
   */
  static generateRandomHexString(length) {
    const characters = '0123456789ABCDEF';
    let result = '';
    for (let i = 0; i < length; i++) {
      result += characters.charAt(Math.floor(Math.random() * characters.length));
    }
    return result;
  }
}
/**
 * Loads PartyRock credential groups into `Tokens`, then creates the TokenManager.
 * Group n is read from AUTH_TOKENS_<n>_REFRESH_URL, AUTH_TOKENS_<n>_ANTI_CSRF_TOKEN and
 * AUTH_TOKENS_<n>_COOKIE for n = 0, 1, 2, ... Scanning stops at the first n for which none
 * of the three variables is set; a group with only some of them set is skipped with a warning.
 */
async function initializeService() {
  console.log('服务初始化中...');
  for (let index = 0; ; index++) {
    const prefix = `AUTH_TOKENS_${index}`;
    const refreshUrl = process.env[`${prefix}_REFRESH_URL`];
    const antiCsrfToken = process.env[`${prefix}_ANTI_CSRF_TOKEN`];
    const cookie = process.env[`${prefix}_COOKIE`];
    if (!refreshUrl && !antiCsrfToken && !cookie) {
      break;
    }
    // Validate before parsing so an incomplete group (e.g. no cookie) never reaches the parser.
    if (!refreshUrl || !antiCsrfToken || !cookie) {
      console.warn(`${prefix} 配置不完整,已跳过`);
      continue;
    }
    const cookies = await Utils.extractTokens(cookie);
    Tokens.push({
      refreshUrl,
      anti_csrftoken_a2z: antiCsrfToken,
      pr_refresh_token: cookies.pr_refresh_token,
      aws_waf_token: cookies['aws-waf-token'],
      idToken: cookies.idToken,
    });
  }
  if (Tokens.length === 0) {
    console.warn('未找到有效的 AUTH_TOKENS 配置,请求将无法转发');
  }
  tokenManager = new TokenManager();
}

await initializeService();
/**
 * Converts an OpenAI-style chat completion request into a PartyRock getCompletion payload.
 */
class ApiClient {
  /**
   * @param {string} modelId - Client-facing model id; must be an own key of CONFIG.MODELS.
   * @throws {Error} When the model id is not supported.
   */
  constructor(modelId) {
    console.log(modelId);
    // Own-property check so ids such as "toString" or "__proto__" do not resolve through
    // Object.prototype.
    if (!Object.prototype.hasOwnProperty.call(CONFIG.MODELS, modelId)) {
      throw new Error(`不支持的模型: ${modelId}`);
    }
    this.modelId = CONFIG.MODELS[modelId];
  }

  /**
   * Flattens OpenAI message content to plain text.
   * @param {string|Array|object|null|undefined} content
   * @returns {string|null} Text, or null when the content carries no usable text.
   */
  processMessageContent(content) {
    if (typeof content === 'string') return content;
    if (Array.isArray(content)) {
      // Keep textual parts only; parts without a string `text` (e.g. image parts) are dropped.
      return content
        .map((part) => (typeof part === 'string' ? part : part?.text))
        .filter((text) => typeof text === 'string')
        .join('\n');
    }
    // typeof null === 'object', so null has to be excluded explicitly.
    if (content !== null && typeof content === 'object') return content.text || null;
    return null;
  }

  /**
   * Builds the upstream payload. System messages are sent as user messages, and consecutive
   * messages with the same role are merged into one, joined by "\r\n". The caller's
   * message objects are not modified.
   * @param {{ messages: Array, temperature?: number, top_p?: number }} request
   * @returns {Promise<object>} PartyRock getCompletion request body.
   * @throws {Error} When `request.messages` is not an array.
   */
  async transformMessages(request) {
    if (!Array.isArray(request?.messages)) {
      throw new Error('messages 参数必须是数组');
    }

    const mergedMessages = [];
    for (const message of request.messages) {
      const text = this.processMessageContent(message?.content);
      if (text === null) continue;
      const role = message.role === 'system' ? 'user' : message.role;
      const previous = mergedMessages[mergedMessages.length - 1];
      if (previous && previous.role === role) {
        previous.content = [{ text: `${previous.content[0].text}\r\n${text}` }];
      } else {
        mergedMessages.push({ ...message, role, content: [{ text }] });
      }
    }

    // `??` (not `||`) so an explicit 0 is honoured; values are clamped to [0, 1] and
    // non-numeric input falls back to the default.
    const toUnitRange = (value, fallback) => {
      const number = Number(value ?? fallback);
      return Number.isFinite(number) ? Math.min(Math.max(number, 0), 1) : fallback;
    };
    const topP = toUnitRange(request.top_p, 0.5);
    const temperature = toUnitRange(request.temperature, 0.95);

    // The PartyRock app id is the path segment after the user segment in the referer URL.
    const referer = CONFIG.DEFAULT_HEADERS.referer;
    console.log(referer);
    const appId = referer?.match(/https:\/\/partyrock\.aws\/u\/[^/]+\/([^/]+)/)?.[1];

    return {
      messages: mergedMessages,
      modelName: this.modelId,
      context: {
        type: 'chat-widget',
        appId,
      },
      options: {
        temperature,
        topP,
      },
      apiVersion: 3,
    };
  }
}
/** Builds OpenAI-style chat completion response objects. */
class MessageProcessor {
  /**
   * @param {string} message - Assistant text (the delta text for stream chunks).
   * @param {string} model - Model id echoed back to the client.
   * @param {boolean} [isStream=false] - Build a `chat.completion.chunk` instead of a full completion.
   * @param {string} [id] - Completion id. Pass the same id for every chunk of one stream;
   *   defaults to a fresh `chatcmpl-<uuid>`.
   * @returns {object}
   */
  static createChatResponse(message, model, isStream = false, id = `chatcmpl-${Utils.uuidv4()}`) {
    const baseResponse = {
      id,
      created: Math.floor(Date.now() / 1000),
      model,
    };
    if (isStream) {
      return {
        ...baseResponse,
        object: 'chat.completion.chunk',
        // Content chunks carry finish_reason: null; only the terminating chunk sets a reason.
        choices: [{
          index: 0,
          delta: { content: message },
          finish_reason: null,
        }],
      };
    }
    return {
      ...baseResponse,
      object: 'chat.completion',
      choices: [{
        index: 0,
        message: {
          role: 'assistant',
          content: message,
        },
        finish_reason: 'stop',
      }],
      // Token usage is not reported by this proxy.
      usage: null,
    };
  }
}
/**
 * Turns PartyRock's SSE completion body (lines of `data: {"text": ...}`) into OpenAI-style
 * output for the client.
 */
class ResponseHandler {
  /**
   * Extracts the payload of an SSE data line.
   * @param {string} line - One raw line of the upstream body.
   * @returns {string|null} Text after "data: ", or null for any other line.
   */
  static #dataPayload(line) {
    const trimmed = line.trim();
    return trimmed.startsWith('data: ') ? trimmed.substring(6) : null;
  }

  /**
   * Relays the upstream body as `chat.completion.chunk` SSE events, then sends a final chunk
   * with finish_reason "stop" and `data: [DONE]`. All chunks share one completion id.
   * If the client disconnects first, the upstream body is destroyed and nothing more is written.
   * @param {object} response - fetch response whose `body` is a Node readable stream.
   * @param {string} model - Model id echoed in every chunk.
   * @param {object} res - Express response.
   * @returns {Promise<void>} Settles once the client response has been finished or closed.
   */
  static async handleStreamResponse(response, model, res) {
    res.setHeader('Content-Type', 'text/event-stream');
    res.setHeader('Cache-Control', 'no-cache');
    res.setHeader('Connection', 'keep-alive');

    const completionId = `chatcmpl-${Utils.uuidv4()}`;
    const decoder = new TextDecoder('utf-8');
    let buffer = '';
    let finished = false;

    const writeEvent = (payload) => res.write(`data: ${JSON.stringify(payload)}\n\n`);

    const relayLine = (line) => {
      const data = ResponseHandler.#dataPayload(line);
      if (!data || data === '[DONE]') return;
      try {
        const json = JSON.parse(data);
        if (json?.text) {
          writeEvent(MessageProcessor.createChatResponse(json.text, model, true, completionId));
        }
      } catch (error) {
        console.error('JSON解析错误:', error);
      }
    };

    return new Promise((resolve) => {
      // Ends the client stream exactly once; `sendStop` adds the terminating finish_reason chunk.
      const finish = (sendStop) => {
        if (finished) return;
        finished = true;
        if (sendStop) {
          writeEvent({
            ...MessageProcessor.createChatResponse('', model, true, completionId),
            choices: [{ index: 0, delta: {}, finish_reason: 'stop' }],
          });
        }
        res.write('data: [DONE]\n\n');
        res.end();
        resolve();
      };

      try {
        const stream = response.body;
        stream.on('data', (chunk) => {
          if (finished) return;
          buffer += decoder.decode(chunk, { stream: true });
          const lines = buffer.split('\n');
          // The last element may be an incomplete line; keep it for the next chunk.
          buffer = lines.pop() || '';
          lines.forEach((line) => relayLine(line));
        });
        stream.on('end', () => {
          if (finished) return;
          // Flush the decoder and relay a final line that arrived without a trailing newline.
          buffer += decoder.decode();
          buffer.split('\n').forEach((line) => relayLine(line));
          buffer = '';
          finish(true);
        });
        stream.on('error', (error) => {
          console.error('流处理错误:', error);
          finish(false);
        });
        // Stop consuming upstream when the client goes away before the stream has ended.
        res.on('close', () => {
          if (finished) return;
          finished = true;
          stream.destroy?.();
          resolve();
        });
      } catch (error) {
        console.error('处理响应错误:', error);
        finish(false);
      }
    });
  }

  /**
   * Reads the whole upstream body, concatenates the `text` of every data line up to
   * `data: [DONE]`, and answers with a single `chat.completion` JSON object.
   * @param {object} response - fetch response.
   * @param {string} model - Model id echoed in the reply.
   * @param {object} res - Express response.
   */
  static async handleNormalResponse(response, model, res) {
    const body = await response.text();
    let fullResponse = '';
    for (const line of body.split('\n')) {
      const data = ResponseHandler.#dataPayload(line);
      if (data === null) continue;
      if (data === '[DONE]') break;
      try {
        const json = JSON.parse(data);
        if (json?.text) {
          fullResponse += json.text;
        }
      } catch (error) {
        console.log("json解析错误");
      }
    }
    res.json(MessageProcessor.createChatResponse(fullResponse, model));
  }
}
// Express application setup
const app = express();
app.use(express.json({ limit: CONFIG.SERVER.BODY_LIMIT }));
app.use(express.urlencoded({ extended: true, limit: CONFIG.SERVER.BODY_LIMIT }));
// CORS for any origin. `allowedHeaders` is deliberately omitted so the cors middleware echoes
// the browser's Access-Control-Request-Headers: a literal "*" does not cover Authorization in
// preflight responses, so browser clients sending their API key would be rejected.
app.use(cors({
  origin: '*',
  methods: ['GET', 'POST', 'OPTIONS'],
}));
// Routes
// Lists every client-facing model id in CONFIG.MODELS in OpenAI's /v1/models list format.
app.get('/hf/v1/models', (req, res) => {
  const data = [];
  for (const id of Object.keys(CONFIG.MODELS)) {
    data.push({
      id,
      object: "model",
      created: Math.floor(Date.now() / 1000),
      owned_by: "partyrock",
    });
  }
  res.json({ object: "list", data });
});
/**
 * OpenAI-compatible chat completions endpoint, proxied to PartyRock's getCompletion stream.
 * Requires "Authorization: Bearer <CONFIG.API.API_KEY>". Replies with SSE chunks when the
 * body's `stream` is truthy, otherwise with a single chat.completion object.
 * Status codes: 401 bad key, 400 unknown model, 500 upstream or internal failure.
 */
app.post('/hf/v1/chat/completions', async (req, res) => {
  try {
    const authToken = req.headers.authorization?.replace('Bearer ', '');
    if (authToken !== CONFIG.API.API_KEY) {
      return res.status(401).json({ error: "Unauthorized" });
    }

    const body = req.body ?? {};
    // An unknown model is a client error; reject it before any token or upstream work.
    if (!Object.prototype.hasOwnProperty.call(CONFIG.MODELS, body.model)) {
      return res.status(400).json({
        error: {
          message: `不支持的模型: ${body.model}`,
          type: 'invalid_request_error',
          param: 'model',
          code: null,
        },
      });
    }

    // Pin the token group now: other requests advance currentIndex while this one waits on the
    // network, and refreshed credentials must be written back to the group actually used.
    const tokenIndex = currentIndex;
    await tokenManager.updateCacheTokens(tokenIndex);
    // Snapshot the headers: retries run after a timer, by which time other requests may have
    // loaded a different token group into CONFIG.DEFAULT_HEADERS.
    const headers = { ...CONFIG.DEFAULT_HEADERS };

    const apiClient = new ApiClient(body.model);
    const requestPayload = await apiClient.transformMessages(body);

    let response;
    for (let attempt = 1; ; attempt++) {
      let failure;
      try {
        console.log("开始请求");
        response = await fetch(`${CONFIG.API.BASE_URL}/stream/getCompletion`, {
          method: "POST",
          headers,
          body: JSON.stringify(requestPayload),
        });
        await tokenManager.updateTokens(response, tokenIndex);
        if (response.status === 200) {
          console.log("请求成功");
          break;
        }
        console.error(JSON.stringify(await response.text(), null, 2));
        failure = new Error(`上游服务请求失败! status: ${response.status}`);
      } catch (error) {
        failure = error;
      }
      if (attempt >= CONFIG.RETRY.MAX_ATTEMPTS) {
        throw failure;
      }
      // Linear back-off before the next attempt.
      await new Promise((resolve) => setTimeout(resolve, CONFIG.RETRY.DELAY_BASE * attempt));
    }

    if (body.stream) {
      await ResponseHandler.handleStreamResponse(response, body.model, res);
    } else {
      await ResponseHandler.handleNormalResponse(response, body.model, res);
    }
  } catch (error) {
    console.error('请求处理失败:', error);
    // Once streaming has started the status line is already sent; only closing is possible.
    if (res.headersSent) {
      res.end();
      return;
    }
    res.status(500).json({
      error: {
        message: error.message,
        type: 'server_error',
        param: null,
        code: error.code || null,
      },
    });
  }
});
// Fallback for every unmatched path: a 404 whose body says the API service is running.
const notFoundHandler = (req, res) => {
  res.status(404).json({ message: "API服务运行正常" });
};
app.use(notFoundHandler);

// Start the HTTP server on the configured port.
const { PORT } = CONFIG.SERVER;
app.listen(PORT, () => {
  console.log(`服务器运行在端口 ${PORT} `);
});