Spaces:
Running
Running
# main.py | |
import asyncio | |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
import uvicorn | |
# 创建 FastAPI 应用实例 | |
app = FastAPI() | |
async def tunnel(websocket: WebSocket): | |
await websocket.accept() | |
tcp_writer = None | |
tcp_reader = None | |
closed = False | |
async def safe_close(): | |
nonlocal closed | |
if not closed: | |
closed = True | |
try: | |
await websocket.close() | |
except Exception: | |
pass | |
if tcp_writer: | |
try: | |
tcp_writer.close() | |
await tcp_writer.wait_closed() | |
except Exception: | |
pass | |
try: | |
# --------------------------- | |
# 1. 等待客户端发来的 CONNECT 请求 | |
# --------------------------- | |
# CONNECT 请求格式示例: | |
# CONNECT destHost:destPort HTTP/1.1\r\nHost: destHost:destPort\r\n\r\n | |
request_text = await websocket.receive_text() | |
lines = request_text.splitlines() | |
if not lines: | |
await websocket.send_text("HTTP/1.1 400 Bad Request\r\n\r\n") | |
await safe_close() | |
return | |
# 解析第一行 | |
first_line = lines[0].strip() | |
parts = first_line.split() | |
if len(parts) < 3 or parts[0].upper() != "CONNECT": | |
await websocket.send_text("HTTP/1.1 400 Bad Request\r\n\r\n") | |
await safe_close() | |
return | |
# 从第二个字段中获取目标主机及端口,如 destHost:destPort | |
dest = parts[1] | |
if ":" not in dest: | |
await websocket.send_text("HTTP/1.1 400 Bad Request\r\n\r\n") | |
await safe_close() | |
return | |
dest_parts = dest.split(":", 1) | |
dest_host = dest_parts[0] | |
try: | |
dest_port = int(dest_parts[1]) | |
except Exception: | |
await websocket.send_text("HTTP/1.1 400 Bad Request\r\n\r\n") | |
await safe_close() | |
return | |
# --------------------------- | |
# 2. 建立到目标主机的 TCP 连接 | |
# --------------------------- | |
try: | |
tcp_reader, tcp_writer = await asyncio.open_connection(dest_host, dest_port) | |
except Exception as e: | |
err_msg = f"HTTP/1.1 502 Bad Gateway\r\n\r\n无法连接 {dest_host}:{dest_port},错误:{e}" | |
await websocket.send_text(err_msg) | |
await safe_close() | |
return | |
# --------------------------- | |
# 3. 向客户端返回 200 成功响应 | |
# --------------------------- | |
await websocket.send_text("HTTP/1.1 200 Connection Established\r\n\r\n") | |
# --------------------------- | |
# 4. 双向数据转发 | |
# --------------------------- | |
async def tcp_to_ws(): | |
""" | |
从 TCP 连接中读取数据,通过 WebSocket 以二进制方式发送给客户端 | |
""" | |
try: | |
while not closed: | |
data = await tcp_reader.read(1024) | |
if not data: | |
break | |
await websocket.send_bytes(data) | |
except Exception as e: | |
# 读取异常或对方关闭连接时退出 | |
print("tcp_to_ws 异常:", e) | |
finally: | |
await safe_close() | |
async def ws_to_tcp(): | |
""" | |
从客户端通过 WebSocket 发送的数据写入 TCP 连接 | |
""" | |
try: | |
while not closed: | |
message = await websocket.receive() | |
# 接收到的数据可能是文本或二进制,这里尽量以二进制方式处理 | |
if "bytes" in message: | |
tcp_writer.write(message["bytes"]) | |
await tcp_writer.drain() | |
elif "text" in message: | |
# 若收到文本数据,则转换为 bytes(可能只在握手阶段出现) | |
tcp_writer.write(message["text"].encode("utf-8")) | |
await tcp_writer.drain() | |
elif message.get("type") == "websocket.disconnect": | |
break | |
except Exception as e: | |
print("ws_to_tcp 异常:", e) | |
finally: | |
await safe_close() | |
# 并发执行数据转发任务,任一方向关闭则结束隧道 | |
await asyncio.gather(tcp_to_ws(), ws_to_tcp()) | |
except WebSocketDisconnect: | |
print("WebSocketDisconnect") | |
await safe_close() | |
except Exception as e: | |
print("WebSocket 隧道处理异常:", e) | |
await safe_close() | |
finally: | |
# 关闭连接 | |
await safe_close() | |
# --------------------------- | |
# 启动服务器:监听 0.0.0.0:7860 | |
# --------------------------- | |
if __name__ == "__main__": | |
uvicorn.run(app, host="0.0.0.0", port=7860) | |