|
import hmac
import json
import logging
import os
import re
import uuid
from datetime import datetime
from threading import Event

import requests
import socketio
from flask import Flask, request, Response, jsonify
|
|
|
# Flask application object; root logging is configured at INFO level.
app = Flask(__name__)
logging.basicConfig(level=logging.INFO)

# Value that callers must present in the x-api-key header (PPLX_KEY).
API_KEY = os.environ.get('PPLX_KEY')

# Optional proxy URL (PROXY_URL) used for both http and https traffic.
proxy_url = os.environ.get('PROXY_URL')

if not proxy_url:
    # No proxy configured: let the Socket.IO client use its default session.
    transport = None
else:
    proxies = dict(http=proxy_url, https=proxy_url)
    transport = requests.Session()
    transport.proxies.update(proxies)

# One Socket.IO client is created at import time and used by the handlers below.
sio = socketio.Client(http_session=transport, logger=True, engineio_logger=True)

# Keyword arguments expanded into sio.connect().
connect_opts = {'transports': ['websocket', 'polling']}

# Headers passed to sio.connect(); cookie and user agent come from the
# PPLX_COOKIE and USER_AGENT environment variables.
sio_opts = {
    'extraHeaders': {
        'Cookie': os.environ.get('PPLX_COOKIE'),
        'User-Agent': os.environ.get('USER_AGENT'),
        'Accept': '*/*',
        'priority': 'u=1, i',
        'Referer': 'https://www.perplexity.ai/',
    },
}
|
|
|
def log_request(ip, route, status):
    """Write one INFO log line describing a handled request.

    Args:
        ip: Client address to record.
        route: Request path to record.
        status: HTTP status code to record.
    """
    timestamp = datetime.now().isoformat()
    entry = f"{timestamp} - {ip} - {route} - {status}"
    logging.info(entry)
|
|
|
def validate_api_key():
    """Check the ``x-api-key`` request header against ``API_KEY``.

    Returns:
        ``None`` when the header matches the configured key, otherwise a
        ``(response, 401)`` tuple carrying a JSON error body.  A rejected
        attempt is also recorded through ``log_request``.
    """
    supplied = request.headers.get('x-api-key')

    # Fail closed: if no key is configured, or the header is absent, the
    # request is rejected.  A plain ``supplied != API_KEY`` comparison would
    # accept header-less requests whenever API_KEY is None (None == None).
    # compare_digest on bytes avoids leaking match length through timing and
    # accepts non-ASCII input, which the str form does not.
    authorized = (
        bool(API_KEY)
        and supplied is not None
        and hmac.compare_digest(supplied.encode('utf-8'), API_KEY.encode('utf-8'))
    )
    if not authorized:
        log_request(request.remote_addr, request.path, 401)
        return jsonify({"error": "Invalid API key"}), 401
    return None
|
|
|
def normalize_content(content):
    """Recursively coerce a message's ``content`` value to a string.

    Strings are returned unchanged, dicts are serialised as JSON (non-ASCII
    characters kept as-is), lists are normalised item by item and joined with
    single spaces, and any other value becomes the empty string.
    """
    if isinstance(content, list):
        return " ".join(normalize_content(part) for part in content)
    if isinstance(content, dict):
        return json.dumps(content, ensure_ascii=False)
    return content if isinstance(content, str) else ""
|
|
|
def calculate_tokens(text):
    """Estimate the number of tokens in ``text``.

    - Pure-ASCII text is counted as whitespace-separated words.
    - Text containing any non-ASCII character (e.g. Chinese, which has no
      spaces between words) is counted character by character.
    """
    has_non_ascii = any(ord(ch) > 0x7F for ch in text)
    if not has_non_ascii:
        return len(text.split())
    return len(text)
|
|
|
@app.route('/')
def root():
    """Serve a JSON description of the API at the site root.

    Logs the request with status 200 and returns a welcome message together
    with the method, headers and body fields of ``/ai/v1/messages``.
    """
    log_request(request.remote_addr, request.path, 200)

    messages_endpoint = {
        "method": "POST",
        "description": "Send a message to the AI",
        "headers": {
            "x-api-key": "Your API key (required)",
            "Content-Type": "application/json",
        },
        "body": {
            "messages": "Array of message objects",
            "stream": "Boolean (true for streaming response)",
            "model": "Model to be used (optional, defaults to claude-3-opus-20240229)",
        },
    }
    overview = {
        "message": "Welcome to the Perplexity AI Proxy API",
        "endpoints": {"/ai/v1/messages": messages_endpoint},
    }
    return jsonify(overview)
|
|
|
@app.route('/ai/v1/messages', methods=['POST'])
def messages():
    """Relay a messages request to Perplexity AI over Socket.IO.

    The ``content`` of every entry in the body's ``messages`` array is
    normalised to text and joined with blank lines into a single prompt.
    When ``stream`` is false the request is delegated to
    ``handle_non_stream``; otherwise (the default) the answer is streamed
    back as server-sent events in this order: ``message_start``,
    ``content_block_start``, ``ping``, zero or more ``content_block_delta``,
    ``content_block_stop``, ``message_delta`` and ``message_stop``.

    Returns:
        The 401 response from ``validate_api_key`` when authentication fails,
        a 400 JSON error when the body cannot be processed, the result of
        ``handle_non_stream`` for non-streaming requests, or a
        ``text/event-stream`` response.
    """
    auth_error = validate_api_key()
    if auth_error:
        return auth_error

    try:
        # silent=True turns an unparsable body into None so the client gets
        # one clear message instead of an AttributeError string.
        json_body = request.get_json(silent=True)
        if not isinstance(json_body, dict):
            raise ValueError("Request body must be a JSON object")

        model = json_body.get('model', 'claude-3-opus-20240229')
        stream = json_body.get('stream', True)

        # Flatten the whole conversation into one prompt string.
        previous_messages = "\n\n".join(
            normalize_content(msg['content']) for msg in json_body['messages']
        )
        input_tokens = calculate_tokens(previous_messages)
        msg_id = str(uuid.uuid4())

        if not stream:
            return handle_non_stream(previous_messages, msg_id, model, input_tokens)

        # Set by the socket callbacks once the answer is finished or failed.
        response_event = Event()
        # Chunks received from the socket callbacks, waiting to be streamed.
        response_text = []

        log_request(request.remote_addr, request.path, 200)

        def generate():
            yield create_event("message_start", {
                "type": "message_start",
                "message": {
                    "id": msg_id,
                    "type": "message",
                    "role": "assistant",
                    "content": [],
                    "model": model,
                    "stop_reason": None,
                    "stop_sequence": None,
                    "usage": {"input_tokens": input_tokens, "output_tokens": 1},
                },
            })
            yield create_event("content_block_start", {
                "type": "content_block_start",
                "index": 0,
                "content_block": {"type": "text", "text": ""},
            })
            yield create_event("ping", {"type": "ping"})

            # Text actually sent to the client.  The pending list is emptied
            # as it is streamed, so the output token count must come from here.
            emitted = []

            def text_delta(chunk):
                emitted.append(chunk)
                return create_event("content_block_delta", {
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "text_delta", "text": chunk},
                })

            def drain_pending():
                while response_text:
                    yield text_delta(response_text.pop(0))

            def on_connect():
                logging.info("Connected to Perplexity AI")
                emit_data = {
                    "version": "2.9",
                    "source": "default",
                    "attachments": [],
                    "language": "en-GB",
                    "timezone": "Europe/London",
                    "mode": "concise",
                    "is_related_query": False,
                    "is_default_related_query": False,
                    "visitor_id": str(uuid.uuid4()),
                    "frontend_context_uuid": str(uuid.uuid4()),
                    "prompt_source": "user",
                    "query_source": "home",
                }
                sio.emit('perplexity_ask', (previous_messages, emit_data))

            def on_query_progress(data):
                if 'text' in data:
                    text = json.loads(data['text'])
                    # Only the last chunk of each progress event is forwarded.
                    chunk = text['chunks'][-1] if text['chunks'] else None
                    if chunk:
                        response_text.append(chunk)
                if data.get('final', False):
                    response_event.set()

            def on_query_complete(data):
                response_event.set()

            def on_disconnect():
                logging.info("Disconnected from Perplexity AI")
                response_event.set()

            def on_connect_error(data):
                logging.error(f"Connection error: {data}")
                response_text.append(f"Error connecting to Perplexity AI: {data}")
                response_event.set()

            # NOTE(review): handlers are registered on the module-level
            # ``sio`` client on every request, so overlapping requests would
            # share one connection and one set of handlers -- confirm the
            # deployment serves one request at a time.
            sio.on('connect', on_connect)
            sio.on('query_progress', on_query_progress)
            sio.on('query_complete', on_query_complete)
            sio.on('disconnect', on_disconnect)
            sio.on('connect_error', on_connect_error)

            try:
                sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])

                while not response_event.is_set():
                    sio.sleep(0.1)
                    yield from drain_pending()

                # Chunks (or the connect-error text) that arrived during the
                # last sleep, just before the event was set, are still
                # pending; without this final pass they would be dropped.
                yield from drain_pending()

            except Exception as e:
                logging.error(f"Error during socket connection: {str(e)}")
                yield text_delta(f"Error during socket connection: {str(e)}")
            finally:
                if sio.connected:
                    sio.disconnect()

            output_tokens = calculate_tokens(''.join(emitted))

            yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
            yield create_event("message_delta", {
                "type": "message_delta",
                "delta": {"stop_reason": "end_turn", "stop_sequence": None},
                "usage": {"input_tokens": input_tokens, "output_tokens": output_tokens},
            })
            yield create_event("message_stop", {"type": "message_stop"})

        return Response(generate(), content_type='text/event-stream')

    except Exception as e:
        logging.error(f"Request error: {str(e)}")
        log_request(request.remote_addr, request.path, 400)
        return jsonify({"error": str(e)}), 400
|
|
|
def handle_non_stream(previous_messages, msg_id, model, input_tokens):
    """Handle a request with ``stream`` set to false and return the full reply.

    Connects the shared Socket.IO client, sends the prompt, waits up to 30
    seconds for the answer to finish, and returns everything collected as a
    single JSON message.

    Args:
        previous_messages: Prompt text sent with the ``perplexity_ask`` event.
        msg_id: Identifier placed in the response's ``id`` field.
        model: Model name echoed back in the response.
        input_tokens: Token count reported under ``usage.input_tokens``.

    Returns:
        A JSON ``Response`` holding the message, or a ``(response, 500)``
        tuple with a JSON error body if anything raises.
    """
    try:
        # Set by the socket callbacks once the answer is finished or failed.
        response_event = Event()
        response_text = []

        def on_connect():
            logging.info("Connected to Perplexity AI")
            emit_data = {
                "version": "2.9",
                "source": "default",
                "attachments": [],
                "language": "en-GB",
                "timezone": "Europe/London",
                "mode": "concise",
                "is_related_query": False,
                "is_default_related_query": False,
                "visitor_id": str(uuid.uuid4()),
                "frontend_context_uuid": str(uuid.uuid4()),
                "prompt_source": "user",
                "query_source": "home",
            }
            sio.emit('perplexity_ask', (previous_messages, emit_data))

        def on_query_progress(data):
            if 'text' in data:
                text = json.loads(data['text'])
                # Only the last chunk of each progress event is collected.
                chunk = text['chunks'][-1] if text['chunks'] else None
                if chunk:
                    response_text.append(chunk)
            if data.get('final', False):
                response_event.set()

        def on_query_complete(data):
            response_event.set()

        def on_disconnect():
            logging.info("Disconnected from Perplexity AI")
            response_event.set()

        def on_connect_error(data):
            logging.error(f"Connection error: {data}")
            response_text.append(f"Error connecting to Perplexity AI: {data}")
            response_event.set()

        sio.on('connect', on_connect)
        sio.on('query_progress', on_query_progress)
        # Registered here as in the streaming path; otherwise a handler left
        # on the shared client by an earlier request may stay bound to this
        # event (NOTE(review): confirm against python-socketio's `on`).
        sio.on('query_complete', on_query_complete)
        sio.on('disconnect', on_disconnect)
        sio.on('connect_error', on_connect_error)

        sio.connect('wss://www.perplexity.ai/', **connect_opts, headers=sio_opts['extraHeaders'])

        # wait() returns False on timeout; whatever text was collected is
        # still returned, but the timeout is no longer silent.
        if not response_event.wait(timeout=30):
            logging.warning("Timed out waiting for Perplexity AI; returning the text received so far")

        answer = ''.join(response_text)
        output_tokens = calculate_tokens(answer)

        full_response = {
            "content": [{"text": answer, "type": "text"}],
            "id": msg_id,
            "model": model,
            "role": "assistant",
            "stop_reason": "end_turn",
            "stop_sequence": None,
            "type": "message",
            "usage": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
            },
        }
        # Record the outcome like the other routes do.
        log_request(request.remote_addr, request.path, 200)
        return Response(json.dumps(full_response, ensure_ascii=False), content_type='application/json')

    except Exception as e:
        logging.error(f"Error during socket connection: {str(e)}")
        log_request(request.remote_addr, request.path, 500)
        return jsonify({"error": str(e)}), 500
    finally:
        if sio.connected:
            sio.disconnect()
|
|
|
@app.errorhandler(404)
def not_found(error):
    """Log the unmatched request and answer with a plain-text 404."""
    status = 404
    log_request(request.remote_addr, request.path, status)
    return "Not Found", status
|
|
|
@app.errorhandler(500)
def server_error(error):
    """Log the error and the failed request, then answer with a plain-text 500."""
    status = 500
    logging.error(f"Server error: {str(error)}")
    log_request(request.remote_addr, request.path, status)
    return "Something broke!", status
|
|
|
def create_event(event, data):
    """Format one server-sent event frame.

    Args:
        event: Event name written on the ``event:`` line.
        data: Payload for the ``data:`` line.  A dict is serialised as JSON
            with non-ASCII characters kept as-is; anything else is
            interpolated unchanged.

    Returns:
        The frame text, terminated by a blank line.
    """
    if not isinstance(data, dict):
        payload = data
    else:
        payload = json.dumps(data, ensure_ascii=False)
    return f"event: {event}\ndata: {payload}\n\n"
|
|
|
if __name__ == '__main__':
    # Listening port comes from the PORT environment variable (default 8081).
    listen_port = int(os.environ.get('PORT', 8081))
    logging.info(f"Perplexity proxy listening on port {listen_port}")

    # Warn at startup when no API key has been configured.
    if not API_KEY:
        logging.warning("Warning: PPLX_KEY environment variable is not set. API key validation will fail.")

    # Bind on all interfaces.
    app.run(host='0.0.0.0', port=listen_port)
|
|