File size: 3,049 Bytes
2a528ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import argparse
import asyncio
import json
import os
import traceback
import urllib.request
from EdgeGPT import Chatbot
from aiohttp import web
# Folder, relative to the current working directory, that http_handler serves
# static files from. It keeps its leading slash because it is concatenated
# directly onto '.' (giving './public') rather than joined with os.path.join.
public_dir = "/public"
async def process_message(user_message, context, _U, locale):
    """Stream raw chatbot responses for a single user message.

    A fresh EdgeGPT ``Chatbot`` is created for every call and is always closed
    afterwards, whether streaming finished normally or failed.

    Args:
        user_message: Prompt text passed to ``ask_stream`` as ``prompt``.
        context: Passed through to ``ask_stream`` as ``webpage_context``.
        _U: Optional ``_U`` cookie value. When truthy it is appended to the
            module-level ``loaded_cookies``; otherwise those cookies are used
            as-is.
        locale: Passed through to ``ask_stream`` as ``locale``.

    Yields:
        Every raw response object produced by ``ask_stream``. If anything
        raises, a final ``{"type": "error", "error": <traceback text>}`` dict
        is yielded instead of propagating the exception, so the caller can
        forward the failure to the client.
    """
    chatbot = None
    try:
        if _U:
            # ``+`` builds a new list, so the shared loaded_cookies is not mutated.
            cookies = loaded_cookies + [{"name": "_U", "value": _U}]
        else:
            cookies = loaded_cookies
        chatbot = await Chatbot.create(cookies=cookies, proxy=args.proxy)
        async for _, response in chatbot.ask_stream(
            prompt=user_message,
            conversation_style="creative",
            raw=True,
            webpage_context=context,
            search_result=True,
            locale=locale,
        ):
            yield response
    except Exception:
        # Deliberately not a bare ``except``: GeneratorExit and CancelledError
        # must propagate. Yielding after catching them would break async
        # generator shutdown (PEP 525).
        yield {"type": "error", "error": traceback.format_exc()}
    finally:
        # Release the chatbot's connection on every exit path.
        if chatbot is not None:
            await chatbot.close()
async def http_handler(request):
    """Serve a static file from the ``public_dir`` folder under the working directory.

    The path ``/`` is mapped to ``/index.html``. Successful responses carry
    ``Cache-Control: no-store``.

    Raises:
        web.HTTPForbidden: The resolved path lies outside the public folder.
        web.HTTPNotFound: The resolved path is not an existing regular file
            (it is missing, or it is a directory).
    """
    file_path = request.path
    if file_path == "/":
        file_path = "/index.html"
    root = os.path.realpath('.' + public_dir)
    full_path = os.path.realpath('.' + public_dir + file_path)
    # A bare startswith(root) would also accept sibling folders such as
    # "./public_old". Require an exact match or a separator right after root.
    if full_path != root and not full_path.startswith(root + os.sep):
        raise web.HTTPForbidden()
    if not os.path.isfile(full_path):
        raise web.HTTPNotFound()
    response = web.FileResponse(full_path)
    response.headers['Cache-Control'] = 'no-store'
    return response
async def websocket_handler(request):
    """Handle a chat websocket connection; each TEXT frame is one chat request.

    Expected frame payload: a JSON object with ``message``, ``context`` and
    ``locale`` keys, plus an optional ``_U`` cookie value. Every item produced
    by ``process_message`` is sent back to the client as a JSON frame.

    A frame is unusable if it is not valid JSON, is not an object, or lacks a
    required key. Such a frame gets an ``{"type": "error", "error": ...}``
    reply and the connection stays open. Non-TEXT frames are ignored.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)
    async for msg in ws:
        if msg.type != web.WSMsgType.TEXT:
            continue
        # Named ``payload`` so it does not shadow the aiohttp ``request`` argument.
        try:
            payload = json.loads(msg.data)
            user_message = payload['message']
            context = payload['context']
            locale = payload['locale']
            _U = payload.get('_U')
        except (ValueError, KeyError, TypeError):
            # ValueError covers json.JSONDecodeError. TypeError covers valid
            # JSON that is not an object (a list, string, number or null).
            await ws.send_json({"type": "error", "error": traceback.format_exc()})
            continue
        async for response in process_message(user_message, context, _U, locale=locale):
            await ws.send_json(response)
    return ws
async def main(host, port):
    """Build the aiohttp application and start listening on host:port.

    Routes: ``/ws/`` goes to ``websocket_handler``; every other GET path goes
    to ``http_handler``. This coroutine returns right after the site has
    started, so the caller must keep the event loop running for the server to
    keep serving.
    """
    application = web.Application()
    router = application.router
    # Registered before the catch-all pattern, which would also match '/ws/'.
    router.add_get('/ws/', websocket_handler)
    router.add_get('/{tail:.*}', http_handler)

    app_runner = web.AppRunner(application)
    await app_runner.setup()
    await web.TCPSite(app_runner, host, port).start()

    print(f"Go to http://{host}:{port} to start chatting!")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--host", "-H", help="host:port for the server", default="localhost:65432")
parser.add_argument("--proxy", "-p", help='proxy address like "http://localhost:7890"',
args = parser.parse_args()
print(f"Proxy used: {args.proxy}")
host, port =":")
port = int(port)
if os.path.isfile("cookies.json"):
with open("cookies.json", 'r') as f:
loaded_cookies = json.load(f)
print("Loaded cookies.json")
loaded_cookies = []
print("cookies.json not found")
loop = asyncio.get_event_loop()
loop.run_until_complete(main(host, port))
except KeyboardInterrupt: