// partyrock2api / index.js
// Author: yxmiler
// Last change: "Update index.js" (commit 71b614e, verified)
import { createHash, timingSafeEqual } from 'node:crypto';
import fetch from 'node-fetch';
import express from 'express';
import cors from 'cors';
import dotenv from 'dotenv';
// Populate process.env from a local .env file via dotenv; this must run before
// the process.env reads below and in initializeService().
dotenv.config();
// Pool of PartyRock account credentials, filled by initializeService() from the
// AUTH_TOKENS_<n>_* environment variables.
const Tokens =[];
// Shared TokenManager instance, assigned in initializeService().
let tokenManager;
// Index into Tokens of the account used for the next upstream request; advanced
// round-robin by TokenManager.updateTokens().
let currentIndex = 0;
// Central configuration for the proxy.
const CONFIG = {
// Upstream service and client authentication.
API: {
// Upstream origin. NOTE(review): not referenced elsewhere in this file; the chat
// route hard-codes its request URL.
BASE_URL: "https://partyrock.aws",
// Key clients must present as "Authorization: Bearer <key>" (checked in the chat route).
// Prefer setting the API_KEY environment variable; the fallback is a public placeholder.
API_KEY: process.env.API_KEY || "sk-123456"// Set your own auth key here; remember to change it
},
// Upstream retry policy: MAX_ATTEMPTS is the total number of attempts per request;
// the wait before retry n is DELAY_BASE * n milliseconds.
RETRY: {
MAX_ATTEMPTS: 1,
DELAY_BASE: 1000
},
// HTTP listener settings. PORT comes from the environment (a string) or defaults
// to 3000; BODY_LIMIT caps JSON and URL-encoded request bodies.
SERVER: {
PORT: process.env.PORT || 3000,
BODY_LIMIT: '5mb'
},
// Public model alias (the `model` value clients send, also listed by
// /hf/v1/models) -> PartyRock `modelName` sent upstream.
MODELS: {
'claude-3-5-haiku-20241022': 'bedrock-anthropic.claude-3-5-haiku',
'claude-3-5-sonnet-20241022': 'bedrock-anthropic.claude-3-5-sonnet-v2-0',
'nova-lite-v1-0': 'bedrock-amazon.nova-lite-v1-0',
'nova-pro-v1-0': 'bedrock-amazon.nova-pro-v1-0',
// NOTE(review): the alias says "7b" but maps to the 8B instruct model; left as-is
// because clients may already send this alias.
'llama3-1-7b': 'bedrock-meta.llama3-1-8b-instruct-v1',
'llama3-1-70b': 'bedrock-meta.llama3-1-70b-instruct-v1',
'mistral-small': 'bedrock-mistral.mistral-small-2402-v1-0',
'mistral-large': 'bedrock-mistral.mistral-large-2407-v1-0'
},
// Browser-like headers sent with every upstream request. "anti-csrftoken-a2z",
// "referer" and "Cookie" start empty and are overwritten from the active account
// by TokenManager.updateCacheTokens() before each request.
DEFAULT_HEADERS: {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
"Connection": "keep-alive",
"Accept": "text/event-stream",
// NOTE(review): confirm the HTTP client can decode every encoding advertised
// here (zstd in particular).
"Accept-Encoding": "gzip, deflate, br, zstd",
"Content-Type": "application/json",
"anti-csrftoken-a2z": "",
"sec-ch-ua-mobile": "?0",
"origin": "https://partyrock.aws",
"sec-fetch-site": "same-origin",
"sec-fetch-mode": "cors",
"sec-fetch-dest": "empty",
"sec-ch-ua-platform": "\"Windows\"",
"sec-ch-ua": "\"Google Chrome\";v=\"131\", \"Chromium\";v=\"131\", \"Not_A Brand\";v=\"24\"",
"referer": "",
"Cookie": "",
"accept-language": "zh-CN,zh;q=0.9",
"priority": "u=1, i"
}
};
/**
 * Round-robin manager for the PartyRock account pool held in the module-level
 * `Tokens` array. The module-level `currentIndex` is the cursor.
 * NOTE(review): the cursor and CONFIG.DEFAULT_HEADERS are shared by all in-flight
 * requests, so concurrent requests can interleave between updateCacheTokens()
 * and updateTokens().
 */
class TokenManager {
  /**
   * Copy the active account's CSRF token, cookies and referer into
   * CONFIG.DEFAULT_HEADERS for the next upstream request.
   * @throws {Error} when no account has been configured.
   */
  async updateCacheTokens() {
    const token = Tokens[currentIndex];
    if (!token) {
      // An empty pool would otherwise surface as an opaque TypeError below.
      throw new Error('No PartyRock token configured (set AUTH_TOKENS_0_REFRESH_URL, AUTH_TOKENS_0_ANTI_CSRF_TOKEN and AUTH_TOKENS_0_COOKIE)');
    }
    CONFIG.DEFAULT_HEADERS["anti-csrftoken-a2z"] = token.anti_csrftoken_a2z;
    CONFIG.DEFAULT_HEADERS.Cookie = `idToken=${token.idToken}; pr_refresh_token=${token.pr_refresh_token}; aws-waf-token=${token.aws_waf_token}`;
    CONFIG.DEFAULT_HEADERS.referer = token.refreshUrl;
  }

  /**
   * Store refreshed credentials from an upstream response on the active account,
   * then advance the cursor to the next account.
   * @param {{headers: {get(name: string): (string|null)}}} response - upstream fetch response
   */
  async updateTokens(response) {
    if (Tokens.length === 0) {
      // Nothing to update, and `% 0` would turn the cursor into NaN.
      return;
    }
    const token = Tokens[currentIndex];
    const newCsrfToken = response.headers.get('anti-csrftoken-a2z');
    if (newCsrfToken) {
      token.anti_csrftoken_a2z = newCsrfToken;
    }
    // The fetch implementation may join several Set-Cookie headers with ", ".
    // The value therefore stops at ";", "," or whitespace. A delimiter is also
    // required before the name, so cookies such as "xidToken" do not match.
    const cookies = response.headers.get('set-cookie');
    const idToken = cookies?.match(/(?:^|[\s;,])idToken=([^;,\s]+)/)?.[1];
    if (idToken) {
      token.idToken = idToken;
    }
    currentIndex = (currentIndex + 1) % Tokens.length;
  }
}
/** Stateless helper functions shared by the rest of the module. */
class Utils {
  /**
   * Parse a Cookie-header style string ("a=1; b=2") into a plain object.
   * Values keep any "=" they contain. Segments without "=" or with an empty
   * name (including the empty segment after a trailing ";") are ignored.
   * Declared async because callers await it; it does no asynchronous work.
   * @param {string|null|undefined} cookieString - anything that is not a string yields {}
   * @returns {Promise<Object<string, string>>}
   */
  static async extractTokens(cookieString) {
    const tokens = {};
    if (typeof cookieString !== 'string') {
      return tokens;
    }
    for (const segment of cookieString.split(';')) {
      const separator = segment.indexOf('=');
      if (separator === -1) continue;
      const key = segment.slice(0, separator).trim();
      if (!key) continue;
      tokens[key] = segment.slice(separator + 1).trim();
    }
    return tokens;
  }

  /** Return a random element of `arr` (undefined for an empty array). */
  static getRandomElement(arr) {
    return arr[Math.floor(Math.random() * arr.length)];
  }

  /**
   * Generate a random version-4 style UUID string. It is based on Math.random(),
   * so it must not be used for security-sensitive identifiers.
   */
  static uuidv4() {
    return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, (c) => {
      const r = (Math.random() * 16) | 0;
      const v = c === 'x' ? r : (r & 0x3) | 0x8;
      return v.toString(16);
    });
  }

  /** Generate `length` random uppercase hexadecimal characters (Math.random based). */
  static generateRandomHexString(length) {
    const characters = '0123456789ABCDEF';
    let result = '';
    for (let i = 0; i < length; i++) {
      result += characters.charAt(Math.floor(Math.random() * characters.length));
    }
    return result;
  }
}
/**
 * Populate the credential pool from environment variables and create the shared
 * TokenManager.
 * Groups are read as AUTH_TOKENS_<n>_REFRESH_URL, AUTH_TOKENS_<n>_ANTI_CSRF_TOKEN
 * and AUTH_TOKENS_<n>_COOKIE for n = 0, 1, 2, …. Scanning stops at the first n
 * where none of the three is set. Incomplete groups are skipped with a warning.
 */
async function initializeService() {
  console.log('服务初始化中...');
  for (let index = 0; ; index++) {
    const refreshUrl = process.env[`AUTH_TOKENS_${index}_REFRESH_URL`];
    const anti_csrftoken_a2z = process.env[`AUTH_TOKENS_${index}_ANTI_CSRF_TOKEN`];
    const cookie = process.env[`AUTH_TOKENS_${index}_COOKIE`];
    if (!refreshUrl && !anti_csrftoken_a2z && !cookie) {
      break;
    }
    // Check completeness before parsing the cookie, so an unset COOKIE variable
    // is reported instead of aborting startup.
    if (!refreshUrl || !anti_csrftoken_a2z || !cookie) {
      console.warn(`AUTH_TOKENS_${index}: incomplete token group (REFRESH_URL, ANTI_CSRF_TOKEN and COOKIE are all required), skipped`);
      continue;
    }
    const cookies = await Utils.extractTokens(cookie);
    Tokens.push({
      refreshUrl,
      anti_csrftoken_a2z,
      pr_refresh_token: cookies["pr_refresh_token"],
      aws_waf_token: cookies["aws-waf-token"],
      idToken: cookies["idToken"],
    });
  }
  if (Tokens.length === 0) {
    console.warn('No complete AUTH_TOKENS_<n> group configured; chat requests will fail until one is added');
  }
  tokenManager = new TokenManager();
}

// Top-level await: the module finishes loading only after the pool is built.
await initializeService();
/**
 * Translates an OpenAI-style chat completion request into a PartyRock
 * `getCompletion` payload for one model.
 */
class ApiClient {
  /**
   * @param {string} modelId - public model alias; must be an own key of CONFIG.MODELS
   * @throws {Error} "不支持的模型: <id>" when the alias is unknown
   */
  constructor(modelId) {
    console.log(modelId);
    // Own-key check: a plain lookup would also accept inherited names such as
    // "toString" and send a function as the upstream model name.
    if (!Object.prototype.hasOwnProperty.call(CONFIG.MODELS, modelId)) {
      throw new Error(`不支持的模型: ${modelId}`);
    }
    this.modelId = CONFIG.MODELS[modelId];
  }

  /**
   * Flatten an OpenAI message `content` value to plain text.
   * - string: returned unchanged
   * - array of parts: the `text` of each part that has a string `text`, joined
   *   with "\n" (non-text parts such as images are dropped)
   * - other object: its `text` property, or null when missing/empty
   * - null, undefined or anything else: null (the caller skips the message)
   * @param {*} content
   * @returns {string|null}
   */
  processMessageContent(content) {
    if (typeof content === 'string') return content;
    if (Array.isArray(content)) {
      return content
        .filter((part) => typeof part?.text === 'string')
        .map((part) => part.text)
        .join('\n');
    }
    // `typeof null === 'object'`, so null must be excluded explicitly.
    if (content !== null && typeof content === 'object') return content.text || null;
    return null;
  }

  /**
   * Build the upstream payload from an OpenAI-style request body. The request
   * object is not modified.
   * - "system" messages are sent with role "user".
   * - Messages whose content yields no text (see processMessageContent) are dropped.
   * - Consecutive messages with the same role are merged; texts are joined by "\r\n".
   * - temperature / top_p default to 0.95 / 0.5 and are clamped to [0, 1].
   * - appId is parsed from CONFIG.DEFAULT_HEADERS.referer.
   * @param {{messages: Array<object>, temperature?: number, top_p?: number}} request
   * @returns {Promise<object>} the PartyRock request payload
   * @throws {Error} when request.messages is not an array
   */
  async transformMessages(request) {
    if (!Array.isArray(request?.messages)) {
      throw new Error('messages must be an array');
    }

    const mergedMessages = [];
    for (const message of request.messages) {
      if (message === null || typeof message !== 'object') continue;
      const text = this.processMessageContent(message.content);
      if (text === null) continue;
      const role = message.role === 'system' ? 'user' : message.role;
      const previous = mergedMessages[mergedMessages.length - 1];
      if (previous && previous.role === role) {
        previous.content = [{ text: `${previous.content[0].text}\r\n${text}` }];
        continue;
      }
      // Copy (keeping any extra fields) instead of mutating the caller's message.
      mergedMessages.push({ ...message, role, content: [{ text }] });
    }

    // `??` keeps an explicit 0 (common for temperature); non-numeric values fall
    // back to the default.
    const normalize = (value, fallback) => {
      const number = Number(value ?? fallback);
      return Number.isFinite(number) ? Math.min(Math.max(number, 0), 1) : fallback;
    };
    const topP = normalize(request.top_p, 0.5);
    const temperature = normalize(request.temperature, 0.95);

    const referer = CONFIG.DEFAULT_HEADERS.referer;
    console.log(referer);
    const appId = referer?.match(/https:\/\/partyrock\.aws\/u\/[^/]+\/([^/]+)/)?.[1];

    return {
      "messages": mergedMessages,
      "modelName": this.modelId,
      "context": {
        "type": "chat-widget",
        "appId": appId
      },
      "options": {
        "temperature": temperature,
        "topP": topP
      },
      "apiVersion": 3
    };
  }
}
/** Builds OpenAI-compatible chat completion response objects. */
class MessageProcessor {
  /**
   * Wrap assistant text in an OpenAI chat completion object.
   * @param {string} message - assistant text: the whole reply, or one delta when streaming
   * @param {string} model - model alias echoed back to the client
   * @param {boolean} [isStream=false] - true builds a `chat.completion.chunk`,
   *   false builds a complete `chat.completion`
   * @returns {object} the response, with a fresh `chatcmpl-<uuid>` id and
   *   `created` in Unix seconds
   */
  static createChatResponse(message, model, isStream = false) {
    const baseResponse = {
      id: `chatcmpl-${Utils.uuidv4()}`,
      created: Math.floor(Date.now() / 1000),
      model,
    };
    if (isStream) {
      return {
        ...baseResponse,
        object: 'chat.completion.chunk',
        // finish_reason belongs to every chunk choice; null means "still generating".
        choices: [{ index: 0, delta: { content: message }, finish_reason: null }],
      };
    }
    return {
      ...baseResponse,
      object: 'chat.completion',
      choices: [{
        index: 0,
        message: { role: 'assistant', content: message },
        finish_reason: 'stop',
      }],
      usage: null,
    };
  }
}
/**
 * Converts PartyRock's server-sent-event completion output (lines of the form
 * `data: {"text": "..."}`) into OpenAI-compatible responses.
 */
class ResponseHandler {
  /**
   * Interpret one upstream line.
   * @param {string} rawLine
   * @returns {{done: true}|{text: string}|null}
   *   - `{done: true}` for `data: [DONE]`
   *   - `{text}` for a data line whose JSON has a truthy `text`
   *   - null otherwise: blank or non-data lines, JSON without text, or
   *     unparseable JSON (which is logged)
   */
  static #parseLine(rawLine) {
    const line = rawLine.trim();
    if (!line.startsWith('data: ')) return null;
    const data = line.substring(6);
    if (data === '[DONE]') return { done: true };
    try {
      const json = JSON.parse(data);
      return json?.text ? { text: json.text } : null;
    } catch (error) {
      console.error('JSON解析错误:', error);
      return null;
    }
  }

  /**
   * Relay the upstream stream to the client as OpenAI `chat.completion.chunk`
   * SSE events, terminated by `data: [DONE]`. Returns once listeners are attached;
   * the response is finished asynchronously.
   * @param {{body: import('node:stream').Readable}} response - upstream response
   * @param {string} model - model alias echoed in every chunk
   * @param {import('node:http').ServerResponse} res - client response
   */
  static async handleStreamResponse(response, model, res) {
    res.setHeader('Content-Type', 'text/event-stream');
    res.setHeader('Cache-Control', 'no-cache');
    res.setHeader('Connection', 'keep-alive');

    let finished = false;
    // Emit the terminating event exactly once, whichever of end/error fires first.
    const finish = () => {
      if (finished) return;
      finished = true;
      res.write('data: [DONE]\n\n');
      res.end();
    };

    try {
      const stream = response.body;
      const decoder = new TextDecoder('utf-8');
      // One id for every chunk of this completion, as OpenAI streams do.
      const completionId = `chatcmpl-${Utils.uuidv4()}`;
      let buffer = '';

      const forward = (line) => {
        const parsed = ResponseHandler.#parseLine(line);
        if (!parsed?.text || finished) return;
        const chunk = { ...MessageProcessor.createChatResponse(parsed.text, model, true), id: completionId };
        res.write(`data: ${JSON.stringify(chunk)}\n\n`);
      };

      stream.on('data', (chunk) => {
        buffer += decoder.decode(chunk, { stream: true });
        const lines = buffer.split('\n');
        // The last piece may be an incomplete line; keep it for the next chunk.
        buffer = lines.pop() ?? '';
        for (const line of lines) {
          forward(line);
        }
      });
      stream.on('end', () => {
        // Flush bytes held by the decoder and a final line without a trailing newline.
        buffer += decoder.decode();
        forward(buffer);
        buffer = '';
        finish();
      });
      stream.on('error', (error) => {
        console.error('流处理错误:', error);
        finish();
      });
      // Stop reading upstream once the client connection has gone away.
      res.on('close', () => {
        if (finished) return;
        finished = true;
        stream.destroy();
      });
    } catch (error) {
      console.error('处理响应错误:', error);
      finish();
    }
  }

  /**
   * Read the whole upstream stream and reply with a single `chat.completion`
   * whose content is every `text` fragment up to `data: [DONE]`, concatenated.
   * @param {{text(): Promise<string>}} response - upstream response
   * @param {string} model - model alias echoed in the response
   * @param {{json(body: object): void}} res - client response
   */
  static async handleNormalResponse(response, model, res) {
    const text = await response.text();
    let fullResponse = '';
    for (const line of text.split('\n')) {
      const parsed = ResponseHandler.#parseLine(line);
      if (parsed?.done) break;
      if (parsed?.text) fullResponse += parsed.text;
    }
    res.json(MessageProcessor.createChatResponse(fullResponse, model));
  }
}
// Express application setup. Middleware runs in registration order.
const app = express();
// Parse JSON and URL-encoded request bodies, each capped at CONFIG.SERVER.BODY_LIMIT.
app.use(express.json({ limit: CONFIG.SERVER.BODY_LIMIT }));
app.use(express.urlencoded({ extended: true, limit: CONFIG.SERVER.BODY_LIMIT }));
// Permissive CORS: any origin, GET/POST/OPTIONS, any request header.
app.use(cors({
origin: '*',
methods: ['GET', 'POST', 'OPTIONS'],
allowedHeaders: ['*']
}));
// Route handlers

/**
 * GET /hf/v1/models: OpenAI-style model list built from the public aliases
 * (the keys of CONFIG.MODELS). The upstream PartyRock model names are not exposed.
 */
app.get('/hf/v1/models', (req, res) => {
  const toModelEntry = (id) => ({
    id,
    object: "model",
    created: Math.floor(Date.now() / 1000),
    owned_by: "partyrock",
  });
  const modelIds = Object.keys(CONFIG.MODELS);
  res.json({
    object: "list",
    data: modelIds.map(toModelEntry),
  });
});
/**
 * Constant-time comparison of a client-supplied bearer token with CONFIG.API.API_KEY.
 * Both values are SHA-256 hashed first, so timingSafeEqual always receives
 * equal-length buffers. The comparison time then does not depend on where the
 * values differ.
 * @param {string|undefined} token - token taken from the Authorization header
 * @returns {boolean} true when the token matches the configured key
 */
function isAuthorizedToken(token) {
  if (typeof token !== 'string') return false;
  const digest = (value) => createHash('sha256').update(String(value)).digest();
  return timingSafeEqual(digest(token), digest(CONFIG.API.API_KEY));
}

/**
 * POST /hf/v1/chat/completions: OpenAI-compatible chat endpoint proxied to PartyRock.
 * - Responds 401 on a wrong API key.
 * - Streams SSE chunks when `stream` is truthy; otherwise returns one
 *   chat.completion object.
 * - Any failure becomes a 500 error body.
 */
app.post('/hf/v1/chat/completions', async (req, res) => {
  try {
    const body = req.body ?? {};
    const authToken = req.headers.authorization?.replace('Bearer ', '');
    if (!isAuthorizedToken(authToken)) {
      return res.status(401).json({ error: "Unauthorized" });
    }
    const apiClient = new ApiClient(body.model);

    // transformMessages() may rewrite the messages it is given, so every attempt
    // transforms a fresh copy of the original body.
    const bodySnapshot = JSON.stringify(body);
    // Load the current account's headers, then build a payload whose appId
    // matches that account's referer.
    const prepareAttempt = async () => {
      await tokenManager.updateCacheTokens();
      return apiClient.transformMessages(JSON.parse(bodySnapshot));
    };
    let requestPayload = await prepareAttempt();

    // At least one attempt, so `response` is always set after the loop.
    const maxAttempts = Math.max(1, CONFIG.RETRY.MAX_ATTEMPTS);
    let response;
    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      let failure;
      try {
        if (attempt > 1) {
          // updateTokens() rotated to the next account after the previous attempt.
          // Reload so this attempt's headers, payload and token updates all
          // belong to that account.
          requestPayload = await prepareAttempt();
        }
        console.log("开始请求");
        response = await fetch(`https://partyrock.aws/stream/getCompletion`, {
          method: "POST",
          headers: { ...CONFIG.DEFAULT_HEADERS },
          body: JSON.stringify(requestPayload),
        });
        await tokenManager.updateTokens(response);
        if (response.status === 200) {
          console.log("请求成功");
          break;
        }
        console.error(JSON.stringify(await response.text(), null, 2));
        failure = new Error(`上游服务请求失败! status: ${response.status}`);
      } catch (error) {
        failure = error;
      }
      if (attempt >= maxAttempts) {
        throw failure;
      }
      // Linear back-off: DELAY_BASE, 2 * DELAY_BASE, ...
      await new Promise((resolve) => setTimeout(resolve, CONFIG.RETRY.DELAY_BASE * attempt));
    }

    if (body.stream) {
      await ResponseHandler.handleStreamResponse(response, body.model, res);
    } else {
      await ResponseHandler.handleNormalResponse(response, body.model, res);
    }
  } catch (error) {
    if (res.headersSent) {
      // A response is already under way; its status can no longer change.
      res.end();
      return;
    }
    res.status(500).json({
      error: {
        message: error.message,
        type: 'server_error',
        param: null,
        code: error.code || null
      }
    });
  }
});
// Catch-all for unmatched routes. It responds 404 with the message
// "API服务运行正常" ("API service is running normally"), presumably so any
// unknown path doubles as a quick liveness check.
app.use((req, res) => {
res.status(404).json({ message: "API服务运行正常" });
});
// Start the HTTP server on CONFIG.SERVER.PORT.
app.listen(CONFIG.SERVER.PORT, () => {
console.log(`服务器运行在端口 ${CONFIG.SERVER.PORT} `);
});