hhz520 commited on
Commit
61517de
1 Parent(s): 280bd81

Upload 170 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +19 -0
  2. README.md +282 -11
  3. app.py +57 -0
  4. bot/baidu/baidu_unit_bot.py +36 -0
  5. bot/baidu/baidu_wenxin.py +107 -0
  6. bot/baidu/baidu_wenxin_session.py +53 -0
  7. bot/bot.py +17 -0
  8. bot/bot_factory.py +50 -0
  9. bot/chatgpt/chat_gpt_bot.py +194 -0
  10. bot/chatgpt/chat_gpt_session.py +101 -0
  11. bot/claude/claude_ai_bot.py +222 -0
  12. bot/claude/claude_ai_session.py +9 -0
  13. bot/linkai/link_ai_bot.py +404 -0
  14. bot/openai/open_ai_bot.py +122 -0
  15. bot/openai/open_ai_image.py +43 -0
  16. bot/openai/open_ai_session.py +73 -0
  17. bot/session_manager.py +91 -0
  18. bot/tongyi/tongyi_qwen_bot.py +185 -0
  19. bot/xunfei/xunfei_spark_bot.py +267 -0
  20. bridge/bridge.py +80 -0
  21. bridge/context.py +71 -0
  22. bridge/reply.py +31 -0
  23. channel/channel.py +43 -0
  24. channel/channel_factory.py +44 -0
  25. channel/chat_channel.py +392 -0
  26. channel/chat_message.py +87 -0
  27. channel/feishu/feishu_channel.py +250 -0
  28. channel/feishu/feishu_message.py +92 -0
  29. channel/terminal/terminal_channel.py +92 -0
  30. channel/wechat/wechat_channel.py +236 -0
  31. channel/wechat/wechat_message.py +102 -0
  32. channel/wechat/wechaty_channel.py +129 -0
  33. channel/wechat/wechaty_message.py +89 -0
  34. channel/wechatcom/README.md +85 -0
  35. channel/wechatcom/wechatcomapp_channel.py +178 -0
  36. channel/wechatcom/wechatcomapp_client.py +21 -0
  37. channel/wechatcom/wechatcomapp_message.py +52 -0
  38. channel/wechatmp/README.md +100 -0
  39. channel/wechatmp/active_reply.py +75 -0
  40. channel/wechatmp/common.py +27 -0
  41. channel/wechatmp/passive_reply.py +211 -0
  42. channel/wechatmp/wechatmp_channel.py +236 -0
  43. channel/wechatmp/wechatmp_client.py +49 -0
  44. channel/wechatmp/wechatmp_message.py +56 -0
  45. channel/wework/run.py +17 -0
  46. channel/wework/wework_channel.py +326 -0
  47. channel/wework/wework_message.py +211 -0
  48. common/const.py +23 -0
  49. common/dequeue.py +33 -0
  50. common/expired_dict.py +42 -0
LICENSE ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2022 zhayujie
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in all
11
+ copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ SOFTWARE.
README.md CHANGED
@@ -1,11 +1,282 @@
1
- ---
2
- title: Webchat
3
- emoji: 🏃
4
- colorFrom: indigo
5
- colorTo: blue
6
- sdk: docker
7
- pinned: false
8
- license: apache-2.0
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 简介
2
+
3
+ > ChatGPT近期以强大的对话和信息整合能力风靡全网,可以写代码、改论文、讲故事,几乎无所不能,这让人不禁有个大胆的想法,能否用他的对话模型把我们的微信打造成一个智能机器人,可以在与好友对话中给出意想不到的回应,而且再也不用担心女朋友影响我们 ~~打游戏~~ 工作了。
4
+
5
+ 最新版本支持的功能如下:
6
+
7
+ - [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信、微信公众号和、业微信、飞书等部署方式
8
+ - [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4, claude, 文心一言, 讯飞星火
9
+ - [x] **语音能力:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型
10
+ - [x] **图像能力:** 支持图片生成、图片识别、图生图(如照片修复),可选择 Dall-E-3, stable diffusion, replicate, midjourney, vision模型
11
+ - [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话等插件
12
+ - [X] **Tool工具:** 与操作系统和互联网交互,支持最新信息搜索、数学计算、天气和资讯查询、网页总结,基于 [chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub) 实现
13
+ - [x] **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、领域知识库、智能客服使用,基于 [LinkAI](https://link-ai.tech/console) 实现
14
+
15
+ > 欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py)实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
16
+
17
+ # 演示
18
+
19
+ https://github.com/zhayujie/chatgpt-on-wechat/assets/26161723/d5154020-36e3-41db-8706-40ce9f3f1b1e
20
+
21
+ Demo made by [Visionn](https://www.wangpc.cc/)
22
+
23
+ # 交流群
24
+
25
+ 添加小助手微信进群,请备注 "wechat":
26
+
27
+ <img width="240" src="./docs/images/contact.jpg">
28
+
29
+ # 更新日志
30
+
31
+ >**2023.11.10:** [1.5.2版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.2),新增飞书通道、图像识别对话、黑名单配置
32
+
33
+ >**2023.11.10:** [1.5.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.0),新增 `gpt-4-turbo`, `dall-e-3`, `tts` 模型接入,完善图像理解&生成、语音识别&生成的多模态能力
34
+
35
+ >**2023.10.16:** 支持通过意图识别使用LinkAI联网搜索、数学计算、网页访问等插件,参考[插件文档](https://docs.link-ai.tech/platform/plugins)
36
+
37
+ >**2023.09.26:** 插件增加 文件/文章链接 一键总结和对话的功能,使用参考:[插件说明](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai#3%E6%96%87%E6%A1%A3%E6%80%BB%E7%BB%93%E5%AF%B9%E8%AF%9D%E5%8A%9F%E8%83%BD)
38
+
39
+ >**2023.08.08:** 接入百度文心一言模型,通过 [插件](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai) 支持 Midjourney 绘图
40
+
41
+ >**2023.06.12:** 接入 [LinkAI](https://link-ai.tech/console) 平台,可在线创建领域知识库,并接入微信、公众号及企业微信中,打造专属客服机器人。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
42
+
43
+ >**2023.04.26:** 支持企业微信应用号部署,兼容插件,并支持语音图片交互,私人助理理想选择,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatcom/README.md)。(contributed by [@lanvent](https://github.com/lanvent) in [#944](https://github.com/zhayujie/chatgpt-on-wechat/pull/944))
44
+
45
+ >**2023.04.05:** 支持微信公众号部署,兼容插件,并支持语音图片交互,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
46
+
47
+ >**2023.04.05:** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663))
48
+
49
+ >**2023.03.25:** 支持插件化开发,目前已实现 多角色切换、文字冒险游戏、管理员指令、Stable Diffusion等插件,使用参考 [#578](https://github.com/zhayujie/chatgpt-on-wechat/issues/578)。(contributed by [@lanvent](https://github.com/lanvent) in [#565](https://github.com/zhayujie/chatgpt-on-wechat/pull/565))
50
+
51
+ >**2023.03.09:** 基于 `whisper API`(后续已接入更多的语音`API`服务) 实现对微信语音消息的解析和回复,添加配置项 `"speech_recognition":true` 即可启用,使用参考 [#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。(contributed by [wanggang1987](https://github.com/wanggang1987) in [#385](https://github.com/zhayujie/chatgpt-on-wechat/pull/385))
52
+
53
+ >**2023.02.09:** 扫码登录存在账号限制风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158)
54
+
55
+ # 快速开始
56
+
57
+ 快速开始文档:[项目搭建文档](https://docs.link-ai.tech/cow/quick-start)
58
+
59
+ ## 准备
60
+
61
+ ### 1. 账号注册
62
+
63
+ 项目默认使用OpenAI接口,需前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。接口需要海外网络访问及绑定信用卡支付。
64
+
65
+ > 默认对话模型是 openai 的 gpt-3.5-turbo,计费方式是约每 1000tokens (约750个英文单词 或 500汉字,包含请求和回复) 消耗 $0.002,图片生成是Dell E模型,每张消耗 $0.016。
66
+
67
+ 项目同时也支持使用 LinkAI 接口,无需代理,可使用 文心、讯飞、GPT-3、GPT-4 等模型,支持 定制化知识库、联网搜索、MJ绘图、文档总结和对话等能力。修改配置即可一键切换,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
68
+
69
+ ### 2.运行环境
70
+
71
+ 支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python`。
72
+ > 建议Python版本在 3.7.1~3.9.X 之间,推荐3.8版本,3.10及以上版本在 MacOS 可用,其他系统上不确定能否正常运行。
73
+
74
+ > 注意:Docker 或 Railway 部署无需安装python环境和下载源码,可直接快进到下一节。
75
+
76
+ **(1) 克隆项目代码:**
77
+
78
+ ```bash
79
+ git clone https://github.com/zhayujie/chatgpt-on-wechat
80
+ cd chatgpt-on-wechat/
81
+ ```
82
+
83
+ **(2) 安装核心依赖 (必选):**
84
+ > 能够使用`itchat`创建机器人,并具有文字交流功能所需的最小依赖集合。
85
+ ```bash
86
+ pip3 install -r requirements.txt
87
+ ```
88
+
89
+ **(3) 拓展依赖 (可选,建议安装):**
90
+
91
+ ```bash
92
+ pip3 install -r requirements-optional.txt
93
+ ```
94
+ > 如果某项依赖安装失败请注释掉对应的行再继续。
95
+
96
+ 其中`tiktoken`要求`python`版本在3.8以上,它用于精确计算会话使用的tokens数量,强烈建议安装。
97
+
98
+
99
+ 使用`google`或`baidu`语音识别需安装`ffmpeg`,
100
+
101
+ 默认的`openai`语音识别不需要安装`ffmpeg`。
102
+
103
+ 参考[#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)
104
+
105
+ 使用`azure`语音功能需安装依赖,并参考[文档](https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/quickstarts/setup-platform?pivots=programming-language-python&tabs=linux%2Cubuntu%2Cdotnet%2Cjre%2Cmaven%2Cnodejs%2Cmac%2Cpypi)的环境要求。
106
+ :
107
+
108
+ ```bash
109
+ pip3 install azure-cognitiveservices-speech
110
+ ```
111
+
112
+ ## 配置
113
+
114
+ 配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
115
+
116
+ ```bash
117
+ cp config-template.json config.json
118
+ ```
119
+
120
+ 然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(请去掉注释):
121
+
122
+ ```bash
123
+ # config.json文件内容示例
124
+ {
125
+ "open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
126
+ "model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
127
+ "proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
128
+ "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
129
+ "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
130
+ "group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
131
+ "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
132
+ "group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
133
+ "image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
134
+ "conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
135
+ "speech_recognition": false, # 是否开启语音识别
136
+ "group_speech_recognition": false, # 是否开启群组语音识别
137
+ "use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/
138
+ "azure_deployment_id": "", # 采用Azure ChatGPT时,模型部署名称
139
+ "azure_api_version": "", # 采用Azure ChatGPT时,API版本
140
+ "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
141
+ # 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
142
+ "subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。",
143
+ "use_linkai": false, # 是否使用LinkAI接口,默认关闭,开启后可国内访问,使用知识库和MJ
144
+ "linkai_api_key": "", # LinkAI Api Key
145
+ "linkai_app_code": "" # LinkAI 应用code
146
+ }
147
+ ```
148
+ **配置说明:**
149
+
150
+ **1.个人聊天**
151
+
152
+ + 个人聊天中,需要以 "bot"或"@bot" 为开头的内容触发机器人,对应配置项 `single_chat_prefix` (如果不需要以前缀触发可以填写 `"single_chat_prefix": [""]`)
153
+ + 机器人回复的内容会以 "[bot] " 作为前缀, 以区分真人,对应的配置项为 `single_chat_reply_prefix` (如果不需要前缀可以填写 `"single_chat_reply_prefix": ""`)
154
+
155
+ **2.群组聊天**
156
+
157
+ + 群组聊天中,群名称需配置在 `group_name_white_list ` 中才能开启群聊自动回复。如果想对所有群聊生效,可以直接填写 `"group_name_white_list": ["ALL_GROUP"]`
158
+ + 默认只要被人 @ 就会触发机器人自动回复;另外群聊天中只要检测到以 "@bot" 开头的内容,同样会自动回复(方便自己触发),这对应配置项 `group_chat_prefix`
159
+ + 可选配置: `group_name_keyword_white_list`配置项支持模糊匹配群名称,`group_chat_keyword`配置项则支持模糊匹配群消息内容,用法与上述两个配置项相同。(Contributed by [evolay](https://github.com/evolay))
160
+ + `group_chat_in_one_session`:使群聊共享一个会话上下文,配置 `["ALL_GROUP"]` 则作用于所有群聊
161
+
162
+ **3.语音识别**
163
+
164
+ + 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
165
+ + 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图);
166
+ + 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊),但是需要配置对应语音合成平台的key,由于itchat协议的限制,只能发送语音mp3文件,若使用wechaty则回复的是微信语音。
167
+
168
+ **4.其他配置**
169
+
170
+ + `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k`, `wenxin` , `claude` , `xunfei`(其中gpt-4 api暂未完全开放,申请通过后可使用)
171
+ + `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
172
+ + `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
173
+ + 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
174
+ + 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档,在[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中检查哪些参数在本项目中是可配置的。
175
+ + `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话)
176
+ + `rate_limit_chatgpt`,`rate_limit_dalle`:每分钟最高问答速率、画图速率,超速后排队按序处理。
177
+ + `clear_memory_commands`: 对话内指令,主动清空前文记忆,字符串数组可自定义指令别名。
178
+ + `hot_reload`: 程序退出后,暂存微信扫码状态,默认关闭。
179
+ + `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作��他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
180
+ + `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
181
+
182
+ **5.LinkAI配置 (可选)**
183
+
184
+ + `use_linkai`: 是否使用LinkAI接口,开启后可国内访问,使用知识库和 `Midjourney` 绘画, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
185
+ + `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
186
+ + `linkai_app_code`: LinkAI 应用code,选填
187
+
188
+ **本说明文档可能会未及时更新,当前所有可选的配置项均在该[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**
189
+
190
+ ## 运行
191
+
192
+ ### 1.本地运行
193
+
194
+ 如果是开发机 **本地运行**,直接在项目根目录下执行:
195
+
196
+ ```bash
197
+ python3 app.py # windows环境下该命令通常为 python app.py
198
+ ```
199
+
200
+ 终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
201
+
202
+ ### 2.服务器部署
203
+
204
+ 使用nohup命令在后台运行程序:
205
+
206
+ ```bash
207
+ touch nohup.out # 首次运行需要新建日志文件
208
+ nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码
209
+ ```
210
+ 扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。
211
+
212
+ > **多账号支持:** 将项目复制多份,分别启动程序,用不同账号扫码登录即可实现同时运行。
213
+
214
+ > **特殊指令:** 用户向机器人发送 **#reset** 即可清空该用户的上下文记忆。
215
+
216
+
217
+ ### 3.Docker部署
218
+
219
+ > 使用docker部署无需下载源码和安装依赖,只需要获取 docker-compose.yml 配置文件并启动容器即可。
220
+
221
+ > 前提是需要安装好 `docker` 及 `docker-compose`,安装成功的表现是执行 `docker -v` 和 `docker-compose version` (或 docker compose version) 可以查看到版本号,可前往 [docker官网](https://docs.docker.com/engine/install/) 进行下载。
222
+
223
+ #### (1) 下载 docker-compose.yml 文件
224
+
225
+ ```bash
226
+ wget https://open-1317903499.cos.ap-guangzhou.myqcloud.com/docker-compose.yml
227
+ ```
228
+
229
+ 下载完成后打开 `docker-compose.yml` 修改所需配置,如 `OPEN_AI_API_KEY` 和 `GROUP_NAME_WHITE_LIST` 等。
230
+
231
+ #### (2) 启动容器
232
+
233
+ 在 `docker-compose.yml` 所在目录下执行以下命令启动容器:
234
+
235
+ ```bash
236
+ sudo docker compose up -d
237
+ ```
238
+
239
+ 运行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。
240
+
241
+ 注意:
242
+
243
+ - 如果 `docker-compose` 是 1.X 版本 则需要执行 `sudo docker-compose up -d` 来启动容器
244
+ - 该命令会自动去 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取 latest 版本的镜像,latest 镜像会在每次项目 release 新的版本时生成
245
+
246
+ 最后运行以下命令可查看容器运行日志,扫描日志中的二维码即可完成登录:
247
+
248
+ ```bash
249
+ sudo docker logs -f chatgpt-on-wechat
250
+ ```
251
+
252
+ #### (3) 插件使用
253
+
254
+ 如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
255
+ 重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
256
+
257
+ ```
258
+ volumes:
259
+ - ./config.json:/app/plugins/config.json
260
+ ```
261
+
262
+ ### 4. Railway部署
263
+
264
+ > Railway 每月提供5刀和最多500小时的免费额度。 (07.11更新: 目前大部分账号已无法免费部署)
265
+
266
+ 1. 进入 [Railway](https://railway.app/template/qApznZ?referralCode=RC3znh)
267
+ 2. 点击 `Deploy Now` 按钮。
268
+ 3. 设置环境变量来重载程序运行的参数,例如`open_ai_api_key`, `character_desc`。
269
+
270
+ **一键部署:**
271
+
272
+ [![Deploy on Railway](https://railway.app/button.svg)](https://railway.app/template/qApznZ?referralCode=RC3znh)
273
+
274
+ ## 常见问题
275
+
276
+ FAQs: <https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs>
277
+
278
+ 或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (beta版本,语料完善中,回复仅供参考)
279
+
280
+ ## 联系
281
+
282
+ 欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。参与更多讨论可加入技术交流群。
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # encoding:utf-8
2
+
3
+ import os
4
+ import signal
5
+ import sys
6
+
7
+ from channel import channel_factory
8
+ from common import const
9
+ from config import load_config
10
+ from plugins import *
11
+
12
+
13
+ def sigterm_handler_wrap(_signo):
14
+ old_handler = signal.getsignal(_signo)
15
+
16
+ def func(_signo, _stack_frame):
17
+ logger.info("signal {} received, exiting...".format(_signo))
18
+ conf().save_user_datas()
19
+ if callable(old_handler): # check old_handler
20
+ return old_handler(_signo, _stack_frame)
21
+ sys.exit(0)
22
+
23
+ signal.signal(_signo, func)
24
+
25
+
26
+ def run():
27
+ try:
28
+ # load config
29
+ load_config()
30
+ # ctrl + c
31
+ sigterm_handler_wrap(signal.SIGINT)
32
+ # kill signal
33
+ sigterm_handler_wrap(signal.SIGTERM)
34
+
35
+ # create channel
36
+ channel_name = conf().get("channel_type", "wx")
37
+
38
+ if "--cmd" in sys.argv:
39
+ channel_name = "terminal"
40
+
41
+ if channel_name == "wxy":
42
+ os.environ["WECHATY_LOG"] = "warn"
43
+ # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
44
+
45
+ channel = channel_factory.create_channel(channel_name)
46
+ if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU]:
47
+ PluginManager().load_plugins()
48
+
49
+ # startup channel
50
+ channel.startup()
51
+ except Exception as e:
52
+ logger.error("App startup failed!")
53
+ logger.exception(e)
54
+
55
+
56
+ if __name__ == "__main__":
57
+ run()
bot/baidu/baidu_unit_bot.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # encoding:utf-8
2
+
3
+ import requests
4
+
5
+ from bot.bot import Bot
6
+ from bridge.reply import Reply, ReplyType
7
+
8
+
9
+ # Baidu Unit对话接口 (可用, 但能力较弱)
10
+ class BaiduUnitBot(Bot):
11
+ def reply(self, query, context=None):
12
+ token = self.get_token()
13
+ url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
14
+ post_data = (
15
+ '{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
16
+ + query
17
+ + '", "hyper_params": {"chat_custom_bot_profile": 1}}}'
18
+ )
19
+ print(post_data)
20
+ headers = {"content-type": "application/x-www-form-urlencoded"}
21
+ response = requests.post(url, data=post_data.encode(), headers=headers)
22
+ if response:
23
+ reply = Reply(
24
+ ReplyType.TEXT,
25
+ response.json()["result"]["context"]["SYS_PRESUMED_HIST"][1],
26
+ )
27
+ return reply
28
+
29
+ def get_token(self):
30
+ access_key = "YOUR_ACCESS_KEY"
31
+ secret_key = "YOUR_SECRET_KEY"
32
+ host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
33
+ response = requests.get(host)
34
+ if response:
35
+ print(response.json())
36
+ return response.json()["access_token"]
bot/baidu/baidu_wenxin.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # encoding:utf-8
2
+
3
+ import requests, json
4
+ from bot.bot import Bot
5
+ from bot.session_manager import SessionManager
6
+ from bridge.context import ContextType
7
+ from bridge.reply import Reply, ReplyType
8
+ from common.log import logger
9
+ from config import conf
10
+ from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
11
+
12
+ BAIDU_API_KEY = conf().get("baidu_wenxin_api_key")
13
+ BAIDU_SECRET_KEY = conf().get("baidu_wenxin_secret_key")
14
+
15
+ class BaiduWenxinBot(Bot):
16
+
17
+ def __init__(self):
18
+ super().__init__()
19
+ wenxin_model = conf().get("baidu_wenxin_model") or "eb-instant"
20
+ if conf().get("model") and conf().get("model") == "wenxin-4":
21
+ wenxin_model = "completions_pro"
22
+ self.sessions = SessionManager(BaiduWenxinSession, model=wenxin_model)
23
+
24
+ def reply(self, query, context=None):
25
+ # acquire reply content
26
+ if context and context.type:
27
+ if context.type == ContextType.TEXT:
28
+ logger.info("[BAIDU] query={}".format(query))
29
+ session_id = context["session_id"]
30
+ reply = None
31
+ if query == "#清除记忆":
32
+ self.sessions.clear_session(session_id)
33
+ reply = Reply(ReplyType.INFO, "记忆已清除")
34
+ elif query == "#清除所有":
35
+ self.sessions.clear_all_session()
36
+ reply = Reply(ReplyType.INFO, "所有人记忆已清除")
37
+ else:
38
+ session = self.sessions.session_query(query, session_id)
39
+ result = self.reply_text(session)
40
+ total_tokens, completion_tokens, reply_content = (
41
+ result["total_tokens"],
42
+ result["completion_tokens"],
43
+ result["content"],
44
+ )
45
+ logger.debug(
46
+ "[BAIDU] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content, completion_tokens)
47
+ )
48
+
49
+ if total_tokens == 0:
50
+ reply = Reply(ReplyType.ERROR, reply_content)
51
+ else:
52
+ self.sessions.session_reply(reply_content, session_id, total_tokens)
53
+ reply = Reply(ReplyType.TEXT, reply_content)
54
+ return reply
55
+ elif context.type == ContextType.IMAGE_CREATE:
56
+ ok, retstring = self.create_img(query, 0)
57
+ reply = None
58
+ if ok:
59
+ reply = Reply(ReplyType.IMAGE_URL, retstring)
60
+ else:
61
+ reply = Reply(ReplyType.ERROR, retstring)
62
+ return reply
63
+
64
+ def reply_text(self, session: BaiduWenxinSession, retry_count=0):
65
+ try:
66
+ logger.info("[BAIDU] model={}".format(session.model))
67
+ access_token = self.get_access_token()
68
+ if access_token == 'None':
69
+ logger.warn("[BAIDU] access token 获取失败")
70
+ return {
71
+ "total_tokens": 0,
72
+ "completion_tokens": 0,
73
+ "content": 0,
74
+ }
75
+ url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + session.model + "?access_token=" + access_token
76
+ headers = {
77
+ 'Content-Type': 'application/json'
78
+ }
79
+ payload = {'messages': session.messages}
80
+ response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
81
+ response_text = json.loads(response.text)
82
+ logger.info(f"[BAIDU] response text={response_text}")
83
+ res_content = response_text["result"]
84
+ total_tokens = response_text["usage"]["total_tokens"]
85
+ completion_tokens = response_text["usage"]["completion_tokens"]
86
+ logger.info("[BAIDU] reply={}".format(res_content))
87
+ return {
88
+ "total_tokens": total_tokens,
89
+ "completion_tokens": completion_tokens,
90
+ "content": res_content,
91
+ }
92
+ except Exception as e:
93
+ need_retry = retry_count < 2
94
+ logger.warn("[BAIDU] Exception: {}".format(e))
95
+ need_retry = False
96
+ self.sessions.clear_session(session.session_id)
97
+ result = {"completion_tokens": 0, "content": "出错了: {}".format(e)}
98
+ return result
99
+
100
+ def get_access_token(self):
101
+ """
102
+ 使用 AK,SK 生成鉴权签名(Access Token)
103
+ :return: access_token,或是None(如果错误)
104
+ """
105
+ url = "https://aip.baidubce.com/oauth/2.0/token"
106
+ params = {"grant_type": "client_credentials", "client_id": BAIDU_API_KEY, "client_secret": BAIDU_SECRET_KEY}
107
+ return str(requests.post(url, params=params).json().get("access_token"))
bot/baidu/baidu_wenxin_session.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bot.session_manager import Session
2
+ from common.log import logger
3
+
4
+ """
5
+ e.g. [
6
+ {"role": "user", "content": "Who won the world series in 2020?"},
7
+ {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
8
+ {"role": "user", "content": "Where was it played?"}
9
+ ]
10
+ """
11
+
12
+
13
+ class BaiduWenxinSession(Session):
14
+ def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
15
+ super().__init__(session_id, system_prompt)
16
+ self.model = model
17
+ # 百度文心不支持system prompt
18
+ # self.reset()
19
+
20
+ def discard_exceeding(self, max_tokens, cur_tokens=None):
21
+ precise = True
22
+ try:
23
+ cur_tokens = self.calc_tokens()
24
+ except Exception as e:
25
+ precise = False
26
+ if cur_tokens is None:
27
+ raise e
28
+ logger.debug("Exception when counting tokens precisely for query: {}".format(e))
29
+ while cur_tokens > max_tokens:
30
+ if len(self.messages) >= 2:
31
+ self.messages.pop(0)
32
+ self.messages.pop(0)
33
+ else:
34
+ logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
35
+ break
36
+ if precise:
37
+ cur_tokens = self.calc_tokens()
38
+ else:
39
+ cur_tokens = cur_tokens - max_tokens
40
+ return cur_tokens
41
+
42
+ def calc_tokens(self):
43
+ return num_tokens_from_messages(self.messages, self.model)
44
+
45
+
46
+ def num_tokens_from_messages(messages, model):
47
+ """Returns the number of tokens used by a list of messages."""
48
+ tokens = 0
49
+ for msg in messages:
50
+ # 官方token计算规则暂不明确: "大约为 token数为 "中文字 + 其他语种单词数 x 1.3"
51
+ # 这里先直接根据字数粗略估算吧,暂不影响正常使用,仅在判断是否丢弃历史会话的时候会有偏差
52
+ tokens += len(msg["content"])
53
+ return tokens
bot/bot.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Auto-replay chat robot abstract class
3
+ """
4
+
5
+
6
+ from bridge.context import Context
7
+ from bridge.reply import Reply
8
+
9
+
10
+ class Bot(object):
11
+ def reply(self, query, context: Context = None) -> Reply:
12
+ """
13
+ bot auto-reply content
14
+ :param req: received message
15
+ :return: reply content
16
+ """
17
+ raise NotImplementedError
bot/bot_factory.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ channel factory
3
+ """
4
+ from common import const
5
+
6
+
7
+ def create_bot(bot_type):
8
+ """
9
+ create a bot_type instance
10
+ :param bot_type: bot type code
11
+ :return: bot instance
12
+ """
13
+ if bot_type == const.BAIDU:
14
+ # 替换Baidu Unit为Baidu文心千帆对话接口
15
+ # from bot.baidu.baidu_unit_bot import BaiduUnitBot
16
+ # return BaiduUnitBot()
17
+ from bot.baidu.baidu_wenxin import BaiduWenxinBot
18
+ return BaiduWenxinBot()
19
+
20
+ elif bot_type == const.CHATGPT:
21
+ # ChatGPT 网页端web接口
22
+ from bot.chatgpt.chat_gpt_bot import ChatGPTBot
23
+ return ChatGPTBot()
24
+
25
+ elif bot_type == const.OPEN_AI:
26
+ # OpenAI 官方对话模型API
27
+ from bot.openai.open_ai_bot import OpenAIBot
28
+ return OpenAIBot()
29
+
30
+ elif bot_type == const.CHATGPTONAZURE:
31
+ # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
32
+ from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
33
+ return AzureChatGPTBot()
34
+
35
+ elif bot_type == const.XUNFEI:
36
+ from bot.xunfei.xunfei_spark_bot import XunFeiBot
37
+ return XunFeiBot()
38
+
39
+ elif bot_type == const.LINKAI:
40
+ from bot.linkai.link_ai_bot import LinkAIBot
41
+ return LinkAIBot()
42
+
43
+ elif bot_type == const.CLAUDEAI:
44
+ from bot.claude.claude_ai_bot import ClaudeAIBot
45
+ return ClaudeAIBot()
46
+
47
+ elif bot_type == const.QWEN:
48
+ from bot.tongyi.tongyi_qwen_bot import TongyiQwenBot
49
+ return TongyiQwenBot()
50
+ raise RuntimeError
bot/chatgpt/chat_gpt_bot.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # encoding:utf-8
2
+
3
+ import time
4
+
5
+ import openai
6
+ import openai.error
7
+ import requests
8
+
9
+ from bot.bot import Bot
10
+ from bot.chatgpt.chat_gpt_session import ChatGPTSession
11
+ from bot.openai.open_ai_image import OpenAIImage
12
+ from bot.session_manager import SessionManager
13
+ from bridge.context import ContextType
14
+ from bridge.reply import Reply, ReplyType
15
+ from common.log import logger
16
+ from common.token_bucket import TokenBucket
17
+ from config import conf, load_config
18
+
19
+
20
+ # OpenAI对话模型API (可用)
21
+ class ChatGPTBot(Bot, OpenAIImage):
22
+ def __init__(self):
23
+ super().__init__()
24
+ # set the default api_key
25
+ openai.api_key = conf().get("open_ai_api_key")
26
+ if conf().get("open_ai_api_base"):
27
+ openai.api_base = conf().get("open_ai_api_base")
28
+ proxy = conf().get("proxy")
29
+ if proxy:
30
+ openai.proxy = proxy
31
+ if conf().get("rate_limit_chatgpt"):
32
+ self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
33
+
34
+ self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
35
+ self.args = {
36
+ "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
37
+ "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
38
+ # "max_tokens":4096, # 回复最大的字符数
39
+ "top_p": conf().get("top_p", 1),
40
+ "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
41
+ "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
42
+ "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
43
+ "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
44
+ }
45
+
46
+ def reply(self, query, context=None):
47
+ # acquire reply content
48
+ if context.type == ContextType.TEXT:
49
+ logger.info("[CHATGPT] query={}".format(query))
50
+
51
+ session_id = context["session_id"]
52
+ reply = None
53
+ clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
54
+ if query in clear_memory_commands:
55
+ self.sessions.clear_session(session_id)
56
+ reply = Reply(ReplyType.INFO, "记忆已清除")
57
+ elif query == "#清除所有":
58
+ self.sessions.clear_all_session()
59
+ reply = Reply(ReplyType.INFO, "所有人记忆已清除")
60
+ elif query == "#更新配置":
61
+ load_config()
62
+ reply = Reply(ReplyType.INFO, "配置已更新")
63
+ if reply:
64
+ return reply
65
+ session = self.sessions.session_query(query, session_id)
66
+ logger.debug("[CHATGPT] session query={}".format(session.messages))
67
+
68
+ api_key = context.get("openai_api_key")
69
+ model = context.get("gpt_model")
70
+ new_args = None
71
+ if model:
72
+ new_args = self.args.copy()
73
+ new_args["model"] = model
74
+ # if context.get('stream'):
75
+ # # reply in stream
76
+ # return self.reply_text_stream(query, new_query, session_id)
77
+
78
+ reply_content = self.reply_text(session, api_key, args=new_args)
79
+ logger.debug(
80
+ "[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
81
+ session.messages,
82
+ session_id,
83
+ reply_content["content"],
84
+ reply_content["completion_tokens"],
85
+ )
86
+ )
87
+ if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
88
+ reply = Reply(ReplyType.ERROR, reply_content["content"])
89
+ elif reply_content["completion_tokens"] > 0:
90
+ self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
91
+ reply = Reply(ReplyType.TEXT, reply_content["content"])
92
+ else:
93
+ reply = Reply(ReplyType.ERROR, reply_content["content"])
94
+ logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
95
+ return reply
96
+
97
+ elif context.type == ContextType.IMAGE_CREATE:
98
+ ok, retstring = self.create_img(query, 0)
99
+ reply = None
100
+ if ok:
101
+ reply = Reply(ReplyType.IMAGE_URL, retstring)
102
+ else:
103
+ reply = Reply(ReplyType.ERROR, retstring)
104
+ return reply
105
+ else:
106
+ reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
107
+ return reply
108
+
109
+ def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict:
110
+ """
111
+ call openai's ChatCompletion to get the answer
112
+ :param session: a conversation session
113
+ :param session_id: session id
114
+ :param retry_count: retry count
115
+ :return: {}
116
+ """
117
+ try:
118
+ if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
119
+ raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
120
+ # if api_key == None, the default openai.api_key will be used
121
+ if args is None:
122
+ args = self.args
123
+ response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
124
+ # logger.debug("[CHATGPT] response={}".format(response))
125
+ # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
126
+ return {
127
+ "total_tokens": response["usage"]["total_tokens"],
128
+ "completion_tokens": response["usage"]["completion_tokens"],
129
+ "content": response.choices[0]["message"]["content"],
130
+ }
131
+ except Exception as e:
132
+ need_retry = retry_count < 2
133
+ result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
134
+ if isinstance(e, openai.error.RateLimitError):
135
+ logger.warn("[CHATGPT] RateLimitError: {}".format(e))
136
+ result["content"] = "提问太快啦,请休息一下再问我吧"
137
+ if need_retry:
138
+ time.sleep(20)
139
+ elif isinstance(e, openai.error.Timeout):
140
+ logger.warn("[CHATGPT] Timeout: {}".format(e))
141
+ result["content"] = "我没有收到你的消息"
142
+ if need_retry:
143
+ time.sleep(5)
144
+ elif isinstance(e, openai.error.APIError):
145
+ logger.warn("[CHATGPT] Bad Gateway: {}".format(e))
146
+ result["content"] = "请再问我一次"
147
+ if need_retry:
148
+ time.sleep(10)
149
+ elif isinstance(e, openai.error.APIConnectionError):
150
+ logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
151
+ result["content"] = "我连接不到你的网络"
152
+ if need_retry:
153
+ time.sleep(5)
154
+ else:
155
+ logger.exception("[CHATGPT] Exception: {}".format(e))
156
+ need_retry = False
157
+ self.sessions.clear_session(session.session_id)
158
+
159
+ if need_retry:
160
+ logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
161
+ return self.reply_text(session, api_key, args, retry_count + 1)
162
+ else:
163
+ return result
164
+
165
+
166
+ class AzureChatGPTBot(ChatGPTBot):
167
+ def __init__(self):
168
+ super().__init__()
169
+ openai.api_type = "azure"
170
+ openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
171
+ self.args["deployment_id"] = conf().get("azure_deployment_id")
172
+
173
+ def create_img(self, query, retry_count=0, api_key=None):
174
+ api_version = "2022-08-03-preview"
175
+ url = "{}dalle/text-to-image?api-version={}".format(openai.api_base, api_version)
176
+ api_key = api_key or openai.api_key
177
+ headers = {"api-key": api_key, "Content-Type": "application/json"}
178
+ try:
179
+ body = {"caption": query, "resolution": conf().get("image_create_size", "256x256")}
180
+ submission = requests.post(url, headers=headers, json=body)
181
+ operation_location = submission.headers["Operation-Location"]
182
+ retry_after = submission.headers["Retry-after"]
183
+ status = ""
184
+ image_url = ""
185
+ while status != "Succeeded":
186
+ logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds")
187
+ time.sleep(int(retry_after))
188
+ response = requests.get(operation_location, headers=headers)
189
+ status = response.json()["status"]
190
+ image_url = response.json()["result"]["contentUrl"]
191
+ return True, image_url
192
+ except Exception as e:
193
+ logger.error("create image error: {}".format(e))
194
+ return False, "图片生成失败"
bot/chatgpt/chat_gpt_session.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bot.session_manager import Session
2
+ from common.log import logger
3
+ from common import const
4
+
5
+ """
6
+ e.g. [
7
+ {"role": "system", "content": "You are a helpful assistant."},
8
+ {"role": "user", "content": "Who won the world series in 2020?"},
9
+ {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
10
+ {"role": "user", "content": "Where was it played?"}
11
+ ]
12
+ """
13
+
14
+
15
+ class ChatGPTSession(Session):
16
+ def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
17
+ super().__init__(session_id, system_prompt)
18
+ self.model = model
19
+ self.reset()
20
+
21
+ def discard_exceeding(self, max_tokens, cur_tokens=None):
22
+ precise = True
23
+ try:
24
+ cur_tokens = self.calc_tokens()
25
+ except Exception as e:
26
+ precise = False
27
+ if cur_tokens is None:
28
+ raise e
29
+ logger.debug("Exception when counting tokens precisely for query: {}".format(e))
30
+ while cur_tokens > max_tokens:
31
+ if len(self.messages) > 2:
32
+ self.messages.pop(1)
33
+ elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
34
+ self.messages.pop(1)
35
+ if precise:
36
+ cur_tokens = self.calc_tokens()
37
+ else:
38
+ cur_tokens = cur_tokens - max_tokens
39
+ break
40
+ elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
41
+ logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
42
+ break
43
+ else:
44
+ logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
45
+ break
46
+ if precise:
47
+ cur_tokens = self.calc_tokens()
48
+ else:
49
+ cur_tokens = cur_tokens - max_tokens
50
+ return cur_tokens
51
+
52
+ def calc_tokens(self):
53
+ return num_tokens_from_messages(self.messages, self.model)
54
+
55
+
56
+ # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
57
+ def num_tokens_from_messages(messages, model):
58
+ """Returns the number of tokens used by a list of messages."""
59
+
60
+ if model in ["wenxin", "xunfei"]:
61
+ return num_tokens_by_character(messages)
62
+
63
+ import tiktoken
64
+
65
+ if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106"]:
66
+ return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
67
+ elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
68
+ "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW]:
69
+ return num_tokens_from_messages(messages, model="gpt-4")
70
+
71
+ try:
72
+ encoding = tiktoken.encoding_for_model(model)
73
+ except KeyError:
74
+ logger.debug("Warning: model not found. Using cl100k_base encoding.")
75
+ encoding = tiktoken.get_encoding("cl100k_base")
76
+ if model == "gpt-3.5-turbo":
77
+ tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
78
+ tokens_per_name = -1 # if there's a name, the role is omitted
79
+ elif model == "gpt-4":
80
+ tokens_per_message = 3
81
+ tokens_per_name = 1
82
+ else:
83
+ logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
84
+ return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
85
+ num_tokens = 0
86
+ for message in messages:
87
+ num_tokens += tokens_per_message
88
+ for key, value in message.items():
89
+ num_tokens += len(encoding.encode(value))
90
+ if key == "name":
91
+ num_tokens += tokens_per_name
92
+ num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
93
+ return num_tokens
94
+
95
+
96
+ def num_tokens_by_character(messages):
97
+ """Returns the number of tokens used by a list of messages."""
98
+ tokens = 0
99
+ for msg in messages:
100
+ tokens += len(msg["content"])
101
+ return tokens
bot/claude/claude_ai_bot.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import time
3
+ import json
4
+ import uuid
5
+ from curl_cffi import requests
6
+ from bot.bot import Bot
7
+ from bot.claude.claude_ai_session import ClaudeAiSession
8
+ from bot.openai.open_ai_image import OpenAIImage
9
+ from bot.session_manager import SessionManager
10
+ from bridge.context import Context, ContextType
11
+ from bridge.reply import Reply, ReplyType
12
+ from common.log import logger
13
+ from config import conf
14
+
15
+
16
+ class ClaudeAIBot(Bot, OpenAIImage):
17
+ def __init__(self):
18
+ super().__init__()
19
+ self.sessions = SessionManager(ClaudeAiSession, model=conf().get("model") or "gpt-3.5-turbo")
20
+ self.claude_api_cookie = conf().get("claude_api_cookie")
21
+ self.proxy = conf().get("proxy")
22
+ self.con_uuid_dic = {}
23
+ if self.proxy:
24
+ self.proxies = {
25
+ "http": self.proxy,
26
+ "https": self.proxy
27
+ }
28
+ else:
29
+ self.proxies = None
30
+ self.error = ""
31
+ self.org_uuid = self.get_organization_id()
32
+
33
+ def generate_uuid(self):
34
+ random_uuid = uuid.uuid4()
35
+ random_uuid_str = str(random_uuid)
36
+ formatted_uuid = f"{random_uuid_str[0:8]}-{random_uuid_str[9:13]}-{random_uuid_str[14:18]}-{random_uuid_str[19:23]}-{random_uuid_str[24:]}"
37
+ return formatted_uuid
38
+
39
+ def reply(self, query, context: Context = None) -> Reply:
40
+ if context.type == ContextType.TEXT:
41
+ return self._chat(query, context)
42
+ elif context.type == ContextType.IMAGE_CREATE:
43
+ ok, res = self.create_img(query, 0)
44
+ if ok:
45
+ reply = Reply(ReplyType.IMAGE_URL, res)
46
+ else:
47
+ reply = Reply(ReplyType.ERROR, res)
48
+ return reply
49
+ else:
50
+ reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
51
+ return reply
52
+
53
+ def get_organization_id(self):
54
+ url = "https://claude.ai/api/organizations"
55
+ headers = {
56
+ 'User-Agent':
57
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
58
+ 'Accept-Language': 'en-US,en;q=0.5',
59
+ 'Referer': 'https://claude.ai/chats',
60
+ 'Content-Type': 'application/json',
61
+ 'Sec-Fetch-Dest': 'empty',
62
+ 'Sec-Fetch-Mode': 'cors',
63
+ 'Sec-Fetch-Site': 'same-origin',
64
+ 'Connection': 'keep-alive',
65
+ 'Cookie': f'{self.claude_api_cookie}'
66
+ }
67
+ try:
68
+ response = requests.get(url, headers=headers, impersonate="chrome110", proxies =self.proxies, timeout=400)
69
+ res = json.loads(response.text)
70
+ uuid = res[0]['uuid']
71
+ except:
72
+ if "App unavailable" in response.text:
73
+ logger.error("IP error: The IP is not allowed to be used on Claude")
74
+ self.error = "ip所在地区不被claude支持"
75
+ elif "Invalid authorization" in response.text:
76
+ logger.error("Cookie error: Invalid authorization of claude, check cookie please.")
77
+ self.error = "无法通过claude身份验证,请检查cookie"
78
+ return None
79
+ return uuid
80
+
81
+ def conversation_share_check(self,session_id):
82
+ if conf().get("claude_uuid") is not None and conf().get("claude_uuid") != "":
83
+ con_uuid = conf().get("claude_uuid")
84
+ return con_uuid
85
+ if session_id not in self.con_uuid_dic:
86
+ self.con_uuid_dic[session_id] = self.generate_uuid()
87
+ self.create_new_chat(self.con_uuid_dic[session_id])
88
+ return self.con_uuid_dic[session_id]
89
+
90
+ def check_cookie(self):
91
+ flag = self.get_organization_id()
92
+ return flag
93
+
94
+ def create_new_chat(self, con_uuid):
95
+ """
96
+ 新建claude对话实体
97
+ :param con_uuid: 对话id
98
+ :return:
99
+ """
100
+ url = f"https://claude.ai/api/organizations/{self.org_uuid}/chat_conversations"
101
+ payload = json.dumps({"uuid": con_uuid, "name": ""})
102
+ headers = {
103
+ 'User-Agent':
104
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
105
+ 'Accept-Language': 'en-US,en;q=0.5',
106
+ 'Referer': 'https://claude.ai/chats',
107
+ 'Content-Type': 'application/json',
108
+ 'Origin': 'https://claude.ai',
109
+ 'DNT': '1',
110
+ 'Connection': 'keep-alive',
111
+ 'Cookie': self.claude_api_cookie,
112
+ 'Sec-Fetch-Dest': 'empty',
113
+ 'Sec-Fetch-Mode': 'cors',
114
+ 'Sec-Fetch-Site': 'same-origin',
115
+ 'TE': 'trailers'
116
+ }
117
+ response = requests.post(url, headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
118
+ # Returns JSON of the newly created conversation information
119
+ return response.json()
120
+
121
+ def _chat(self, query, context, retry_count=0) -> Reply:
122
+ """
123
+ 发起对话请求
124
+ :param query: 请求提示词
125
+ :param context: 对话上下文
126
+ :param retry_count: 当前递归重试次数
127
+ :return: 回复
128
+ """
129
+ if retry_count >= 2:
130
+ # exit from retry 2 times
131
+ logger.warn("[CLAUDEAI] failed after maximum number of retry times")
132
+ return Reply(ReplyType.ERROR, "请再问我一次吧")
133
+
134
+ try:
135
+ session_id = context["session_id"]
136
+ if self.org_uuid is None:
137
+ return Reply(ReplyType.ERROR, self.error)
138
+
139
+ session = self.sessions.session_query(query, session_id)
140
+ con_uuid = self.conversation_share_check(session_id)
141
+
142
+ model = conf().get("model") or "gpt-3.5-turbo"
143
+ # remove system message
144
+ if session.messages[0].get("role") == "system":
145
+ if model == "wenxin" or model == "claude":
146
+ session.messages.pop(0)
147
+ logger.info(f"[CLAUDEAI] query={query}")
148
+
149
+ # do http request
150
+ base_url = "https://claude.ai"
151
+ payload = json.dumps({
152
+ "completion": {
153
+ "prompt": f"{query}",
154
+ "timezone": "Asia/Kolkata",
155
+ "model": "claude-2"
156
+ },
157
+ "organization_uuid": f"{self.org_uuid}",
158
+ "conversation_uuid": f"{con_uuid}",
159
+ "text": f"{query}",
160
+ "attachments": []
161
+ })
162
+ headers = {
163
+ 'User-Agent':
164
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
165
+ 'Accept': 'text/event-stream, text/event-stream',
166
+ 'Accept-Language': 'en-US,en;q=0.5',
167
+ 'Referer': 'https://claude.ai/chats',
168
+ 'Content-Type': 'application/json',
169
+ 'Origin': 'https://claude.ai',
170
+ 'DNT': '1',
171
+ 'Connection': 'keep-alive',
172
+ 'Cookie': f'{self.claude_api_cookie}',
173
+ 'Sec-Fetch-Dest': 'empty',
174
+ 'Sec-Fetch-Mode': 'cors',
175
+ 'Sec-Fetch-Site': 'same-origin',
176
+ 'TE': 'trailers'
177
+ }
178
+
179
+ res = requests.post(base_url + "/api/append_message", headers=headers, data=payload,impersonate="chrome110",proxies= self.proxies,timeout=400)
180
+ if res.status_code == 200 or "pemission" in res.text:
181
+ # execute success
182
+ decoded_data = res.content.decode("utf-8")
183
+ decoded_data = re.sub('\n+', '\n', decoded_data).strip()
184
+ data_strings = decoded_data.split('\n')
185
+ completions = []
186
+ for data_string in data_strings:
187
+ json_str = data_string[6:].strip()
188
+ data = json.loads(json_str)
189
+ if 'completion' in data:
190
+ completions.append(data['completion'])
191
+
192
+ reply_content = ''.join(completions)
193
+
194
+ if "rate limi" in reply_content:
195
+ logger.error("rate limit error: The conversation has reached the system speed limit and is synchronized with Cladue. Please go to the official website to check the lifting time")
196
+ return Reply(ReplyType.ERROR, "对话达到系统速率限制,与cladue同步,请进入官网查看解除限制时间")
197
+ logger.info(f"[CLAUDE] reply={reply_content}, total_tokens=invisible")
198
+ self.sessions.session_reply(reply_content, session_id, 100)
199
+ return Reply(ReplyType.TEXT, reply_content)
200
+ else:
201
+ flag = self.check_cookie()
202
+ if flag == None:
203
+ return Reply(ReplyType.ERROR, self.error)
204
+
205
+ response = res.json()
206
+ error = response.get("error")
207
+ logger.error(f"[CLAUDE] chat failed, status_code={res.status_code}, "
208
+ f"msg={error.get('message')}, type={error.get('type')}, detail: {res.text}, uuid: {con_uuid}")
209
+
210
+ if res.status_code >= 500:
211
+ # server error, need retry
212
+ time.sleep(2)
213
+ logger.warn(f"[CLAUDE] do retry, times={retry_count}")
214
+ return self._chat(query, context, retry_count + 1)
215
+ return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
216
+
217
+ except Exception as e:
218
+ logger.exception(e)
219
+ # retry
220
+ time.sleep(2)
221
+ logger.warn(f"[CLAUDE] do retry, times={retry_count}")
222
+ return self._chat(query, context, retry_count + 1)
bot/claude/claude_ai_session.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from bot.session_manager import Session
2
+
3
+
4
+ class ClaudeAiSession(Session):
5
+ def __init__(self, session_id, system_prompt=None, model="claude"):
6
+ super().__init__(session_id, system_prompt)
7
+ self.model = model
8
+ # claude逆向不支持role prompt
9
+ # self.reset()
bot/linkai/link_ai_bot.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # access LinkAI knowledge base platform
2
+ # docs: https://link-ai.tech/platform/link-app/wechat
3
+
4
+ import time
5
+
6
+ import requests
7
+
8
+ import config
9
+ from bot.bot import Bot
10
+ from bot.chatgpt.chat_gpt_session import ChatGPTSession
11
+ from bot.session_manager import SessionManager
12
+ from bridge.context import Context, ContextType
13
+ from bridge.reply import Reply, ReplyType
14
+ from common.log import logger
15
+ from config import conf, pconf
16
+ import threading
17
+ from common import memory, utils
18
+ import base64
19
+
20
+
21
+ class LinkAIBot(Bot):
22
+ # authentication failed
23
+ AUTH_FAILED_CODE = 401
24
+ NO_QUOTA_CODE = 406
25
+
26
+ def __init__(self):
27
+ super().__init__()
28
+ self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo")
29
+ self.args = {}
30
+
31
+ def reply(self, query, context: Context = None) -> Reply:
32
+ if context.type == ContextType.TEXT:
33
+ return self._chat(query, context)
34
+ elif context.type == ContextType.IMAGE_CREATE:
35
+ if not conf().get("text_to_image"):
36
+ logger.warn("[LinkAI] text_to_image is not enabled, ignore the IMAGE_CREATE request")
37
+ return Reply(ReplyType.TEXT, "")
38
+ ok, res = self.create_img(query, 0)
39
+ if ok:
40
+ reply = Reply(ReplyType.IMAGE_URL, res)
41
+ else:
42
+ reply = Reply(ReplyType.ERROR, res)
43
+ return reply
44
+ else:
45
+ reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
46
+ return reply
47
+
48
+ def _chat(self, query, context, retry_count=0) -> Reply:
49
+ """
50
+ 发起对话请求
51
+ :param query: 请求提示词
52
+ :param context: 对话上下文
53
+ :param retry_count: 当前递归重试次数
54
+ :return: 回复
55
+ """
56
+ if retry_count > 2:
57
+ # exit from retry 2 times
58
+ logger.warn("[LINKAI] failed after maximum number of retry times")
59
+ return Reply(ReplyType.TEXT, "请再问我一次吧")
60
+
61
+ try:
62
+ # load config
63
+ if context.get("generate_breaked_by"):
64
+ logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context")
65
+ app_code = None
66
+ else:
67
+ plugin_app_code = self._find_group_mapping_code(context)
68
+ app_code = context.kwargs.get("app_code") or plugin_app_code or conf().get("linkai_app_code")
69
+ linkai_api_key = conf().get("linkai_api_key")
70
+
71
+ session_id = context["session_id"]
72
+ session_message = self.sessions.session_msg_query(query, session_id)
73
+ logger.debug(f"[LinkAI] session={session_message}, session_id={session_id}")
74
+
75
+ # image process
76
+ img_cache = memory.USER_IMAGE_CACHE.get(session_id)
77
+ if img_cache:
78
+ messages = self._process_image_msg(app_code=app_code, session_id=session_id, query=query, img_cache=img_cache)
79
+ if messages:
80
+ session_message = messages
81
+
82
+ model = conf().get("model")
83
+ # remove system message
84
+ if session_message[0].get("role") == "system":
85
+ if app_code or model == "wenxin":
86
+ session_message.pop(0)
87
+
88
+ body = {
89
+ "app_code": app_code,
90
+ "messages": session_message,
91
+ "model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
92
+ "temperature": conf().get("temperature"),
93
+ "top_p": conf().get("top_p", 1),
94
+ "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
95
+ "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
96
+ }
97
+ file_id = context.kwargs.get("file_id")
98
+ if file_id:
99
+ body["file_id"] = file_id
100
+ logger.info(f"[LINKAI] query={query}, app_code={app_code}, model={body.get('model')}, file_id={file_id}")
101
+ headers = {"Authorization": "Bearer " + linkai_api_key}
102
+
103
+ # do http request
104
+ base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
105
+ res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
106
+ timeout=conf().get("request_timeout", 180))
107
+ if res.status_code == 200:
108
+ # execute success
109
+ response = res.json()
110
+ reply_content = response["choices"][0]["message"]["content"]
111
+ total_tokens = response["usage"]["total_tokens"]
112
+ logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
113
+ self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
114
+
115
+ agent_suffix = self._fetch_agent_suffix(response)
116
+ if agent_suffix:
117
+ reply_content += agent_suffix
118
+ if not agent_suffix:
119
+ knowledge_suffix = self._fetch_knowledge_search_suffix(response)
120
+ if knowledge_suffix:
121
+ reply_content += knowledge_suffix
122
+ # image process
123
+ if response["choices"][0].get("img_urls"):
124
+ thread = threading.Thread(target=self._send_image, args=(context.get("channel"), context, response["choices"][0].get("img_urls")))
125
+ thread.start()
126
+ if response["choices"][0].get("text_content"):
127
+ reply_content = response["choices"][0].get("text_content")
128
+ return Reply(ReplyType.TEXT, reply_content)
129
+
130
+ else:
131
+ response = res.json()
132
+ error = response.get("error")
133
+ logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
134
+ f"msg={error.get('message')}, type={error.get('type')}")
135
+
136
+ if res.status_code >= 500:
137
+ # server error, need retry
138
+ time.sleep(2)
139
+ logger.warn(f"[LINKAI] do retry, times={retry_count}")
140
+ return self._chat(query, context, retry_count + 1)
141
+
142
+ return Reply(ReplyType.TEXT, "提问太快啦,请休息一下再问我吧")
143
+
144
+ except Exception as e:
145
+ logger.exception(e)
146
+ # retry
147
+ time.sleep(2)
148
+ logger.warn(f"[LINKAI] do retry, times={retry_count}")
149
+ return self._chat(query, context, retry_count + 1)
150
+
151
+ def _process_image_msg(self, app_code: str, session_id: str, query:str, img_cache: dict):
152
+ try:
153
+ enable_image_input = False
154
+ app_info = self._fetch_app_info(app_code)
155
+ if not app_info:
156
+ logger.debug(f"[LinkAI] not found app, can't process images, app_code={app_code}")
157
+ return None
158
+ plugins = app_info.get("data").get("plugins")
159
+ for plugin in plugins:
160
+ if plugin.get("input_type") and "IMAGE" in plugin.get("input_type"):
161
+ enable_image_input = True
162
+ if not enable_image_input:
163
+ return
164
+ msg = img_cache.get("msg")
165
+ path = img_cache.get("path")
166
+ msg.prepare()
167
+ logger.info(f"[LinkAI] query with images, path={path}")
168
+ messages = self._build_vision_msg(query, path)
169
+ memory.USER_IMAGE_CACHE[session_id] = None
170
+ return messages
171
+ except Exception as e:
172
+ logger.exception(e)
173
+
174
+ def _find_group_mapping_code(self, context):
175
+ try:
176
+ if context.kwargs.get("isgroup"):
177
+ group_name = context.kwargs.get("msg").from_user_nickname
178
+ if config.plugin_config and config.plugin_config.get("linkai"):
179
+ linkai_config = config.plugin_config.get("linkai")
180
+ group_mapping = linkai_config.get("group_app_map")
181
+ if group_mapping and group_name:
182
+ return group_mapping.get(group_name)
183
+ except Exception as e:
184
+ logger.exception(e)
185
+ return None
186
+
187
+ def _build_vision_msg(self, query: str, path: str):
188
+ try:
189
+ suffix = utils.get_path_suffix(path)
190
+ with open(path, "rb") as file:
191
+ base64_str = base64.b64encode(file.read()).decode('utf-8')
192
+ messages = [{
193
+ "role": "user",
194
+ "content": [
195
+ {
196
+ "type": "text",
197
+ "text": query
198
+ },
199
+ {
200
+ "type": "image_url",
201
+ "image_url": {
202
+ "url": f"data:image/{suffix};base64,{base64_str}"
203
+ }
204
+ }
205
+ ]
206
+ }]
207
+ return messages
208
+ except Exception as e:
209
+ logger.exception(e)
210
+
211
+ def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
212
+ if retry_count >= 2:
213
+ # exit from retry 2 times
214
+ logger.warn("[LINKAI] failed after maximum number of retry times")
215
+ return {
216
+ "total_tokens": 0,
217
+ "completion_tokens": 0,
218
+ "content": "请再问我一次吧"
219
+ }
220
+
221
+ try:
222
+ body = {
223
+ "app_code": app_code,
224
+ "messages": session.messages,
225
+ "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
226
+ "temperature": conf().get("temperature"),
227
+ "top_p": conf().get("top_p", 1),
228
+ "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
229
+ "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
230
+ }
231
+ if self.args.get("max_tokens"):
232
+ body["max_tokens"] = self.args.get("max_tokens")
233
+ headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
234
+
235
+ # do http request
236
+ base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
237
+ res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
238
+ timeout=conf().get("request_timeout", 180))
239
+ if res.status_code == 200:
240
+ # execute success
241
+ response = res.json()
242
+ reply_content = response["choices"][0]["message"]["content"]
243
+ total_tokens = response["usage"]["total_tokens"]
244
+ logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
245
+ return {
246
+ "total_tokens": total_tokens,
247
+ "completion_tokens": response["usage"]["completion_tokens"],
248
+ "content": reply_content,
249
+ }
250
+
251
+ else:
252
+ response = res.json()
253
+ error = response.get("error")
254
+ logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
255
+ f"msg={error.get('message')}, type={error.get('type')}")
256
+
257
+ if res.status_code >= 500:
258
+ # server error, need retry
259
+ time.sleep(2)
260
+ logger.warn(f"[LINKAI] do retry, times={retry_count}")
261
+ return self.reply_text(session, app_code, retry_count + 1)
262
+
263
+ return {
264
+ "total_tokens": 0,
265
+ "completion_tokens": 0,
266
+ "content": "提问太快啦,请休息一下再问我吧"
267
+ }
268
+
269
+ except Exception as e:
270
+ logger.exception(e)
271
+ # retry
272
+ time.sleep(2)
273
+ logger.warn(f"[LINKAI] do retry, times={retry_count}")
274
+ return self.reply_text(session, app_code, retry_count + 1)
275
+
276
+ def _fetch_app_info(self, app_code: str):
277
+ headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
278
+ # do http request
279
+ base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
280
+ params = {"app_code": app_code}
281
+ res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
282
+ if res.status_code == 200:
283
+ return res.json()
284
+ else:
285
+ logger.warning(f"[LinkAI] find app info exception, res={res}")
286
+
287
+ def create_img(self, query, retry_count=0, api_key=None):
288
+ try:
289
+ logger.info("[LinkImage] image_query={}".format(query))
290
+ headers = {
291
+ "Content-Type": "application/json",
292
+ "Authorization": f"Bearer {conf().get('linkai_api_key')}"
293
+ }
294
+ data = {
295
+ "prompt": query,
296
+ "n": 1,
297
+ "model": conf().get("text_to_image") or "dall-e-2",
298
+ "response_format": "url",
299
+ "img_proxy": conf().get("image_proxy")
300
+ }
301
+ url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/images/generations"
302
+ res = requests.post(url, headers=headers, json=data, timeout=(5, 90))
303
+ t2 = time.time()
304
+ image_url = res.json()["data"][0]["url"]
305
+ logger.info("[OPEN_AI] image_url={}".format(image_url))
306
+ return True, image_url
307
+
308
+ except Exception as e:
309
+ logger.error(format(e))
310
+ return False, "画图出现问题,请休息一下再问我吧"
311
+
312
+
313
+ def _fetch_knowledge_search_suffix(self, response) -> str:
314
+ try:
315
+ if response.get("knowledge_base"):
316
+ search_hit = response.get("knowledge_base").get("search_hit")
317
+ first_similarity = response.get("knowledge_base").get("first_similarity")
318
+ logger.info(f"[LINKAI] knowledge base, search_hit={search_hit}, first_similarity={first_similarity}")
319
+ plugin_config = pconf("linkai")
320
+ if plugin_config and plugin_config.get("knowledge_base") and plugin_config.get("knowledge_base").get("search_miss_text_enabled"):
321
+ search_miss_similarity = plugin_config.get("knowledge_base").get("search_miss_similarity")
322
+ search_miss_text = plugin_config.get("knowledge_base").get("search_miss_suffix")
323
+ if not search_hit:
324
+ return search_miss_text
325
+ if search_miss_similarity and float(search_miss_similarity) > first_similarity:
326
+ return search_miss_text
327
+ except Exception as e:
328
+ logger.exception(e)
329
+
330
+
331
+ def _fetch_agent_suffix(self, response):
332
+ try:
333
+ plugin_list = []
334
+ logger.debug(f"[LinkAgent] res={response}")
335
+ if response.get("agent") and response.get("agent").get("chain") and response.get("agent").get("need_show_plugin"):
336
+ chain = response.get("agent").get("chain")
337
+ suffix = "\n\n- - - - - - - - - - - -"
338
+ i = 0
339
+ for turn in chain:
340
+ plugin_name = turn.get('plugin_name')
341
+ suffix += "\n"
342
+ need_show_thought = response.get("agent").get("need_show_thought")
343
+ if turn.get("thought") and plugin_name and need_show_thought:
344
+ suffix += f"{turn.get('thought')}\n"
345
+ if plugin_name:
346
+ plugin_list.append(turn.get('plugin_name'))
347
+ suffix += f"{turn.get('plugin_icon')} {turn.get('plugin_name')}"
348
+ if turn.get('plugin_input'):
349
+ suffix += f":{turn.get('plugin_input')}"
350
+ if i < len(chain) - 1:
351
+ suffix += "\n"
352
+ i += 1
353
+ logger.info(f"[LinkAgent] use plugins: {plugin_list}")
354
+ return suffix
355
+ except Exception as e:
356
+ logger.exception(e)
357
+
358
+
359
+ def _send_image(self, channel, context, image_urls):
360
+ if not image_urls:
361
+ return
362
+ try:
363
+ for url in image_urls:
364
+ reply = Reply(ReplyType.IMAGE_URL, url)
365
+ channel.send(reply, context)
366
+ except Exception as e:
367
+ logger.error(e)
368
+
369
+
370
+ class LinkAISessionManager(SessionManager):
371
+ def session_msg_query(self, query, session_id):
372
+ session = self.build_session(session_id)
373
+ messages = session.messages + [{"role": "user", "content": query}]
374
+ return messages
375
+
376
+ def session_reply(self, reply, session_id, total_tokens=None, query=None):
377
+ session = self.build_session(session_id)
378
+ if query:
379
+ session.add_query(query)
380
+ session.add_reply(reply)
381
+ try:
382
+ max_tokens = conf().get("conversation_max_tokens", 2500)
383
+ tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
384
+ logger.debug(f"[LinkAI] chat history, before tokens={total_tokens}, now tokens={tokens_cnt}")
385
+ except Exception as e:
386
+ logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
387
+ return session
388
+
389
+
390
+ class LinkAISession(ChatGPTSession):
391
+ def calc_tokens(self):
392
+ if not self.messages:
393
+ return 0
394
+ return len(str(self.messages))
395
+
396
+ def discard_exceeding(self, max_tokens, cur_tokens=None):
397
+ cur_tokens = self.calc_tokens()
398
+ if cur_tokens > max_tokens:
399
+ for i in range(0, len(self.messages)):
400
+ if i > 0 and self.messages[i].get("role") == "assistant" and self.messages[i - 1].get("role") == "user":
401
+ self.messages.pop(i)
402
+ self.messages.pop(i - 1)
403
+ return self.calc_tokens()
404
+ return cur_tokens
bot/openai/open_ai_bot.py ADDED
@@ -0,0 +1,122 @@
 
 
1
+ # encoding:utf-8
2
+
3
+ import time
4
+
5
+ import openai
6
+ import openai.error
7
+
8
+ from bot.bot import Bot
9
+ from bot.openai.open_ai_image import OpenAIImage
10
+ from bot.openai.open_ai_session import OpenAISession
11
+ from bot.session_manager import SessionManager
12
+ from bridge.context import ContextType
13
+ from bridge.reply import Reply, ReplyType
14
+ from common.log import logger
15
+ from config import conf
16
+
17
+ user_session = dict()
18
+
19
+
20
+ # OpenAI对话模型API (可用)
21
+ class OpenAIBot(Bot, OpenAIImage):
22
+ def __init__(self):
23
+ super().__init__()
24
+ openai.api_key = conf().get("open_ai_api_key")
25
+ if conf().get("open_ai_api_base"):
26
+ openai.api_base = conf().get("open_ai_api_base")
27
+ proxy = conf().get("proxy")
28
+ if proxy:
29
+ openai.proxy = proxy
30
+
31
+ self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
32
+ self.args = {
33
+ "model": conf().get("model") or "text-davinci-003", # 对话模型的名称
34
+ "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
35
+ "max_tokens": 1200, # 回复最大的字符数
36
+ "top_p": 1,
37
+ "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
38
+ "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
39
+ "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
40
+ "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
41
+ "stop": ["\n\n\n"],
42
+ }
43
+
44
+ def reply(self, query, context=None):
45
+ # acquire reply content
46
+ if context and context.type:
47
+ if context.type == ContextType.TEXT:
48
+ logger.info("[OPEN_AI] query={}".format(query))
49
+ session_id = context["session_id"]
50
+ reply = None
51
+ if query == "#清除记忆":
52
+ self.sessions.clear_session(session_id)
53
+ reply = Reply(ReplyType.INFO, "记忆已清除")
54
+ elif query == "#清除所有":
55
+ self.sessions.clear_all_session()
56
+ reply = Reply(ReplyType.INFO, "所有人记忆已清除")
57
+ else:
58
+ session = self.sessions.session_query(query, session_id)
59
+ result = self.reply_text(session)
60
+ total_tokens, completion_tokens, reply_content = (
61
+ result["total_tokens"],
62
+ result["completion_tokens"],
63
+ result["content"],
64
+ )
65
+ logger.debug(
66
+ "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
67
+ )
68
+
69
+ if total_tokens == 0:
70
+ reply = Reply(ReplyType.ERROR, reply_content)
71
+ else:
72
+ self.sessions.session_reply(reply_content, session_id, total_tokens)
73
+ reply = Reply(ReplyType.TEXT, reply_content)
74
+ return reply
75
+ elif context.type == ContextType.IMAGE_CREATE:
76
+ ok, retstring = self.create_img(query, 0)
77
+ reply = None
78
+ if ok:
79
+ reply = Reply(ReplyType.IMAGE_URL, retstring)
80
+ else:
81
+ reply = Reply(ReplyType.ERROR, retstring)
82
+ return reply
83
+
84
+ def reply_text(self, session: OpenAISession, retry_count=0):
85
+ try:
86
+ response = openai.Completion.create(prompt=str(session), **self.args)
87
+ res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
88
+ total_tokens = response["usage"]["total_tokens"]
89
+ completion_tokens = response["usage"]["completion_tokens"]
90
+ logger.info("[OPEN_AI] reply={}".format(res_content))
91
+ return {
92
+ "total_tokens": total_tokens,
93
+ "completion_tokens": completion_tokens,
94
+ "content": res_content,
95
+ }
96
+ except Exception as e:
97
+ need_retry = retry_count < 2
98
+ result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
99
+ if isinstance(e, openai.error.RateLimitError):
100
+ logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
101
+ result["content"] = "提问太快啦,请休息一下再问我吧"
102
+ if need_retry:
103
+ time.sleep(20)
104
+ elif isinstance(e, openai.error.Timeout):
105
+ logger.warn("[OPEN_AI] Timeout: {}".format(e))
106
+ result["content"] = "我没有收到你的消息"
107
+ if need_retry:
108
+ time.sleep(5)
109
+ elif isinstance(e, openai.error.APIConnectionError):
110
+ logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
111
+ need_retry = False
112
+ result["content"] = "我连接不到你的网络"
113
+ else:
114
+ logger.warn("[OPEN_AI] Exception: {}".format(e))
115
+ need_retry = False
116
+ self.sessions.clear_session(session.session_id)
117
+
118
+ if need_retry:
119
+ logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1))
120
+ return self.reply_text(session, retry_count + 1)
121
+ else:
122
+ return result
bot/openai/open_ai_image.py ADDED
@@ -0,0 +1,43 @@
 
 
1
+ import time
2
+
3
+ import openai
4
+ import openai.error
5
+
6
+ from common.log import logger
7
+ from common.token_bucket import TokenBucket
8
+ from config import conf
9
+
10
+
11
+ # OPENAI提供的画图接口
12
+ class OpenAIImage(object):
13
+ def __init__(self):
14
+ openai.api_key = conf().get("open_ai_api_key")
15
+ if conf().get("rate_limit_dalle"):
16
+ self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
17
+
18
+ def create_img(self, query, retry_count=0, api_key=None):
19
+ try:
20
+ if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
21
+ return False, "请求太快了,请休息一下再问我吧"
22
+ logger.info("[OPEN_AI] image_query={}".format(query))
23
+ response = openai.Image.create(
24
+ api_key=api_key,
25
+ prompt=query, # 图片描述
26
+ n=1, # 每次生成图片的数量
27
+ model=conf().get("text_to_image") or "dall-e-2",
28
+ # size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
29
+ )
30
+ image_url = response["data"][0]["url"]
31
+ logger.info("[OPEN_AI] image_url={}".format(image_url))
32
+ return True, image_url
33
+ except openai.error.RateLimitError as e:
34
+ logger.warn(e)
35
+ if retry_count < 1:
36
+ time.sleep(5)
37
+ logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
38
+ return self.create_img(query, retry_count + 1)
39
+ else:
40
+ return False, "画图出现问题,请休息一下再问我吧"
41
+ except Exception as e:
42
+ logger.exception(e)
43
+ return False, "画图出现问题,请休息一下再问我吧"
bot/openai/open_ai_session.py ADDED
@@ -0,0 +1,73 @@
 
 
1
+ from bot.session_manager import Session
2
+ from common.log import logger
3
+
4
+
5
+ class OpenAISession(Session):
6
+ def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
7
+ super().__init__(session_id, system_prompt)
8
+ self.model = model
9
+ self.reset()
10
+
11
+ def __str__(self):
12
+ # 构造对话模型的输入
13
+ """
14
+ e.g. Q: xxx
15
+ A: xxx
16
+ Q: xxx
17
+ """
18
+ prompt = ""
19
+ for item in self.messages:
20
+ if item["role"] == "system":
21
+ prompt += item["content"] + "<|endoftext|>\n\n\n"
22
+ elif item["role"] == "user":
23
+ prompt += "Q: " + item["content"] + "\n"
24
+ elif item["role"] == "assistant":
25
+ prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"
26
+
27
+ if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
28
+ prompt += "A: "
29
+ return prompt
30
+
31
+ def discard_exceeding(self, max_tokens, cur_tokens=None):
32
+ precise = True
33
+ try:
34
+ cur_tokens = self.calc_tokens()
35
+ except Exception as e:
36
+ precise = False
37
+ if cur_tokens is None:
38
+ raise e
39
+ logger.debug("Exception when counting tokens precisely for query: {}".format(e))
40
+ while cur_tokens > max_tokens:
41
+ if len(self.messages) > 1:
42
+ self.messages.pop(0)
43
+ elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
44
+ self.messages.pop(0)
45
+ if precise:
46
+ cur_tokens = self.calc_tokens()
47
+ else:
48
+ cur_tokens = len(str(self))
49
+ break
50
+ elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
51
+ logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
52
+ break
53
+ else:
54
+ logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
55
+ break
56
+ if precise:
57
+ cur_tokens = self.calc_tokens()
58
+ else:
59
+ cur_tokens = len(str(self))
60
+ return cur_tokens
61
+
62
+ def calc_tokens(self):
63
+ return num_tokens_from_string(str(self), self.model)
64
+
65
+
66
+ # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
67
+ def num_tokens_from_string(string: str, model: str) -> int:
68
+ """Returns the number of tokens in a text string."""
69
+ import tiktoken
70
+
71
+ encoding = tiktoken.encoding_for_model(model)
72
+ num_tokens = len(encoding.encode(string, disallowed_special=()))
73
+ return num_tokens
bot/session_manager.py ADDED
@@ -0,0 +1,91 @@
 
 
1
+ from common.expired_dict import ExpiredDict
2
+ from common.log import logger
3
+ from config import conf
4
+
5
+
6
+ class Session(object):
7
+ def __init__(self, session_id, system_prompt=None):
8
+ self.session_id = session_id
9
+ self.messages = []
10
+ if system_prompt is None:
11
+ self.system_prompt = conf().get("character_desc", "")
12
+ else:
13
+ self.system_prompt = system_prompt
14
+
15
+ # 重置会话
16
+ def reset(self):
17
+ system_item = {"role": "system", "content": self.system_prompt}
18
+ self.messages = [system_item]
19
+
20
+ def set_system_prompt(self, system_prompt):
21
+ self.system_prompt = system_prompt
22
+ self.reset()
23
+
24
+ def add_query(self, query):
25
+ user_item = {"role": "user", "content": query}
26
+ self.messages.append(user_item)
27
+
28
+ def add_reply(self, reply):
29
+ assistant_item = {"role": "assistant", "content": reply}
30
+ self.messages.append(assistant_item)
31
+
32
+ def discard_exceeding(self, max_tokens=None, cur_tokens=None):
33
+ raise NotImplementedError
34
+
35
+ def calc_tokens(self):
36
+ raise NotImplementedError
37
+
38
+
39
+ class SessionManager(object):
40
+ def __init__(self, sessioncls, **session_args):
41
+ if conf().get("expires_in_seconds"):
42
+ sessions = ExpiredDict(conf().get("expires_in_seconds"))
43
+ else:
44
+ sessions = dict()
45
+ self.sessions = sessions
46
+ self.sessioncls = sessioncls
47
+ self.session_args = session_args
48
+
49
+ def build_session(self, session_id, system_prompt=None):
50
+ """
51
+ 如果session_id不在sessions中,创建一个新的session并添加到sessions中
52
+ 如果system_prompt不为空,会更新session的system_prompt并重置session
53
+ """
54
+ if session_id is None:
55
+ return self.sessioncls(session_id, system_prompt, **self.session_args)
56
+
57
+ if session_id not in self.sessions:
58
+ self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
59
+ elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
60
+ self.sessions[session_id].set_system_prompt(system_prompt)
61
+ session = self.sessions[session_id]
62
+ return session
63
+
64
+ def session_query(self, query, session_id):
65
+ session = self.build_session(session_id)
66
+ session.add_query(query)
67
+ try:
68
+ max_tokens = conf().get("conversation_max_tokens", 1000)
69
+ total_tokens = session.discard_exceeding(max_tokens, None)
70
+ logger.debug("prompt tokens used={}".format(total_tokens))
71
+ except Exception as e:
72
+ logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e)))
73
+ return session
74
+
75
+ def session_reply(self, reply, session_id, total_tokens=None):
76
+ session = self.build_session(session_id)
77
+ session.add_reply(reply)
78
+ try:
79
+ max_tokens = conf().get("conversation_max_tokens", 1000)
80
+ tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
81
+ logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
82
+ except Exception as e:
83
+ logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
84
+ return session
85
+
86
+ def clear_session(self, session_id):
87
+ if session_id in self.sessions:
88
+ del self.sessions[session_id]
89
+
90
+ def clear_all_session(self):
91
+ self.sessions.clear()
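For orientation, the bots added in this commit all drive SessionManager with the same two-step pattern: session_query before the model call, session_reply after it. A minimal sketch of that flow, assuming the config defaults above; call_model is a hypothetical stand-in for the real API request:

    from bot.session_manager import SessionManager
    from bot.chatgpt.chat_gpt_session import ChatGPTSession

    sessions = SessionManager(ChatGPTSession, model="gpt-3.5-turbo")

    def answer(query: str, session_id: str) -> str:
        # append the user query and trim history to conversation_max_tokens
        session = sessions.session_query(query, session_id)
        reply_text, total_tokens = call_model(session.messages)  # hypothetical model call
        # append the assistant reply and trim again using the reported token usage
        sessions.session_reply(reply_text, session_id, total_tokens)
        return reply_text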
bot/tongyi/tongyi_qwen_bot.py ADDED
@@ -0,0 +1,185 @@
 
 
1
+ # encoding:utf-8
2
+
3
+ import json
4
+ import time
5
+ from typing import List, Tuple
6
+
7
+ import openai
8
+ import openai.error
9
+ import broadscope_bailian
10
+ from broadscope_bailian import ChatQaMessage
11
+
12
+ from bot.bot import Bot
13
+ from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
14
+ from bot.session_manager import SessionManager
15
+ from bridge.context import ContextType
16
+ from bridge.reply import Reply, ReplyType
17
+ from common.log import logger
18
+ from config import conf, load_config
19
+
20
+ class TongyiQwenBot(Bot):
21
+ def __init__(self):
22
+ super().__init__()
23
+ self.access_key_id = conf().get("qwen_access_key_id")
24
+ self.access_key_secret = conf().get("qwen_access_key_secret")
25
+ self.agent_key = conf().get("qwen_agent_key")
26
+ self.app_id = conf().get("qwen_app_id")
27
+ self.node_id = conf().get("qwen_node_id") or ""
28
+ self.api_key_client = broadscope_bailian.AccessTokenClient(access_key_id=self.access_key_id, access_key_secret=self.access_key_secret)
29
+ self.api_key_expired_time = self.set_api_key()
30
+ self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "qwen")
31
+ self.temperature = conf().get("temperature", 0.2) # 值在[0,1]之间,越大表示回复越具有不确定性
32
+ self.top_p = conf().get("top_p", 1)
33
+
34
+ def reply(self, query, context=None):
35
+ # acquire reply content
36
+ if context.type == ContextType.TEXT:
37
+ logger.info("[TONGYI] query={}".format(query))
38
+
39
+ session_id = context["session_id"]
40
+ reply = None
41
+ clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
42
+ if query in clear_memory_commands:
43
+ self.sessions.clear_session(session_id)
44
+ reply = Reply(ReplyType.INFO, "记忆已清除")
45
+ elif query == "#清除所有":
46
+ self.sessions.clear_all_session()
47
+ reply = Reply(ReplyType.INFO, "所有人记忆已清除")
48
+ elif query == "#更新配置":
49
+ load_config()
50
+ reply = Reply(ReplyType.INFO, "配置已更新")
51
+ if reply:
52
+ return reply
53
+ session = self.sessions.session_query(query, session_id)
54
+ logger.debug("[TONGYI] session query={}".format(session.messages))
55
+
56
+ reply_content = self.reply_text(session)
57
+ logger.debug(
58
+ "[TONGYI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
59
+ session.messages,
60
+ session_id,
61
+ reply_content["content"],
62
+ reply_content["completion_tokens"],
63
+ )
64
+ )
65
+ if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
66
+ reply = Reply(ReplyType.ERROR, reply_content["content"])
67
+ elif reply_content["completion_tokens"] > 0:
68
+ self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
69
+ reply = Reply(ReplyType.TEXT, reply_content["content"])
70
+ else:
71
+ reply = Reply(ReplyType.ERROR, reply_content["content"])
72
+ logger.debug("[TONGYI] reply {} used 0 tokens.".format(reply_content))
73
+ return reply
74
+
75
+ else:
76
+ reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
77
+ return reply
78
+
79
+ def reply_text(self, session: BaiduWenxinSession, retry_count=0) -> dict:
80
+ """
81
+ call bailian's ChatCompletion to get the answer
82
+ :param session: a conversation session
83
+ :param retry_count: retry count
84
+ :return: {}
85
+ """
86
+ try:
87
+ prompt, history = self.convert_messages_format(session.messages)
88
+ self.update_api_key_if_expired()
89
+ # NOTE 阿里百炼的call()函数参数比较奇怪, top_k参数表示top_p, top_p参数表示temperature, 可以参考文档 https://help.aliyun.com/document_detail/2587502.htm
90
+ response = broadscope_bailian.Completions().call(app_id=self.app_id, prompt=prompt, history=history, top_k=self.top_p, top_p=self.temperature)
91
+ completion_content = self.get_completion_content(response, self.node_id)
92
+ completion_tokens, total_tokens = self.calc_tokens(session.messages, completion_content)
93
+ return {
94
+ "total_tokens": total_tokens,
95
+ "completion_tokens": completion_tokens,
96
+ "content": completion_content,
97
+ }
98
+ except Exception as e:
99
+ need_retry = retry_count < 2
100
+ result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
101
+ if isinstance(e, openai.error.RateLimitError):
102
+ logger.warn("[TONGYI] RateLimitError: {}".format(e))
103
+ result["content"] = "提问太快啦,请休息一下再问我吧"
104
+ if need_retry:
105
+ time.sleep(20)
106
+ elif isinstance(e, openai.error.Timeout):
107
+ logger.warn("[TONGYI] Timeout: {}".format(e))
108
+ result["content"] = "我没有收到你的消息"
109
+ if need_retry:
110
+ time.sleep(5)
111
+ elif isinstance(e, openai.error.APIError):
112
+ logger.warn("[TONGYI] Bad Gateway: {}".format(e))
113
+ result["content"] = "请再问我一次"
114
+ if need_retry:
115
+ time.sleep(10)
116
+ elif isinstance(e, openai.error.APIConnectionError):
117
+ logger.warn("[TONGYI] APIConnectionError: {}".format(e))
118
+ need_retry = False
119
+ result["content"] = "我连接不到你的网络"
120
+ else:
121
+ logger.exception("[TONGYI] Exception: {}".format(e))
122
+ need_retry = False
123
+ self.sessions.clear_session(session.session_id)
124
+
125
+ if need_retry:
126
+ logger.warn("[TONGYI] 第{}次重试".format(retry_count + 1))
127
+ return self.reply_text(session, retry_count + 1)
128
+ else:
129
+ return result
130
+
131
+ def set_api_key(self):
132
+ api_key, expired_time = self.api_key_client.create_token(agent_key=self.agent_key)
133
+ broadscope_bailian.api_key = api_key
134
+ return expired_time
135
+ def update_api_key_if_expired(self):
136
+ if time.time() > self.api_key_expired_time:
137
+ self.api_key_expired_time = self.set_api_key()
138
+
139
+ def convert_messages_format(self, messages) -> Tuple[str, List[ChatQaMessage]]:
140
+ history = []
141
+ user_content = ''
142
+ assistant_content = ''
143
+ for message in messages:
144
+ role = message.get('role')
145
+ if role == 'user':
146
+ user_content += message.get('content')
147
+ elif role == 'assistant':
148
+ assistant_content = message.get('content')
149
+ history.append(ChatQaMessage(user_content, assistant_content))
150
+ user_content = ''
151
+ assistant_content = ''
152
+ if user_content == '':
153
+ raise Exception('no user message')
154
+ return user_content, history
155
+
156
+ def get_completion_content(self, response, node_id):
157
+ text = response['Data']['Text']
158
+ if node_id == '':
159
+ return text
160
+ # TODO: 当使用流程编排创建大模型应用时,响应结构如下,最终结果在['finalResult'][node_id]['response']['text']中,暂时先这么写
161
+ # {
162
+ # 'Success': True,
163
+ # 'Code': None,
164
+ # 'Message': None,
165
+ # 'Data': {
166
+ # 'ResponseId': '9822f38dbacf4c9b8daf5ca03a2daf15',
167
+ # 'SessionId': 'session_id',
168
+ # 'Text': '{"finalResult":{"LLM_T7islK":{"params":{"modelId":"qwen-plus-v1","prompt":"${systemVars.query}${bizVars.Text}"},"response":{"text":"作为一个AI语言模型,我没有年龄,因为我没有生日。\n我只是一个程序,没有生命和身体。"}}}}',
169
+ # 'Thoughts': [],
170
+ # 'Debug': {},
171
+ # 'DocReferences': []
172
+ # },
173
+ # 'RequestId': '8e11d31551ce4c3f83f49e6e0dd998b0',
174
+ # 'Failed': None
175
+ # }
176
+ text_dict = json.loads(text)
177
+ completion_content = text_dict['finalResult'][node_id]['response']['text']
178
+ return completion_content
179
+
180
+ def calc_tokens(self, messages, completion_content):
181
+ completion_tokens = len(completion_content)
182
+ prompt_tokens = 0
183
+ for message in messages:
184
+ prompt_tokens += len(message["content"])
185
+ return completion_tokens, prompt_tokens + completion_tokens
bot/xunfei/xunfei_spark_bot.py ADDED
@@ -0,0 +1,267 @@
 
 
1
+ # encoding:utf-8
2
+
3
+ import requests, json
4
+ from bot.bot import Bot
5
+ from bot.session_manager import SessionManager
6
+ from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
7
+ from bridge.context import ContextType, Context
8
+ from bridge.reply import Reply, ReplyType
9
+ from common.log import logger
10
+ from config import conf
11
+ from common import const
12
+ import time
13
+ import _thread as thread
14
+ import datetime
15
+ from datetime import datetime
16
+ from wsgiref.handlers import format_date_time
17
+ from urllib.parse import urlencode
18
+ import base64
19
+ import ssl
20
+ import hashlib
21
+ import hmac
22
+ import json
23
+ from time import mktime
24
+ from urllib.parse import urlparse
25
+ import websocket
26
+ import queue
27
+ import threading
28
+ import random
29
+
30
+ # 消息队列 map
31
+ queue_map = dict()
32
+
33
+ # 响应队列 map
34
+ reply_map = dict()
35
+
36
+
37
+ class XunFeiBot(Bot):
38
+ def __init__(self):
39
+ super().__init__()
40
+ self.app_id = conf().get("xunfei_app_id")
41
+ self.api_key = conf().get("xunfei_api_key")
42
+ self.api_secret = conf().get("xunfei_api_secret")
43
+ # v2.0版本为: "generalv2"
44
+ # v1.5版本为 "general"
45
+ # v3.0版本为: "generalv3"
46
+ self.domain = "generalv3"
47
+ # v2.0版本为: "ws://spark-api.xf-yun.com/v2.1/chat"
48
+ # v1.5版本为: "ws://spark-api.xf-yun.com/v1.1/chat"
49
+ # v3.0版本为: "ws://spark-api.xf-yun.com/v3.1/chat"
50
+ self.spark_url = "ws://spark-api.xf-yun.com/v3.1/chat"
51
+ self.host = urlparse(self.spark_url).netloc
52
+ self.path = urlparse(self.spark_url).path
53
+ # 和wenxin使用相同的session机制
54
+ self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)
55
+
56
+ def reply(self, query, context: Context = None) -> Reply:
57
+ if context.type == ContextType.TEXT:
58
+ logger.info("[XunFei] query={}".format(query))
59
+ session_id = context["session_id"]
60
+ request_id = self.gen_request_id(session_id)
61
+ reply_map[request_id] = ""
62
+ session = self.sessions.session_query(query, session_id)
63
+ threading.Thread(target=self.create_web_socket,
64
+ args=(session.messages, request_id)).start()
65
+ depth = 0
66
+ time.sleep(0.1)
67
+ t1 = time.time()
68
+ usage = {}
69
+ while depth <= 300:
70
+ try:
71
+ data_queue = queue_map.get(request_id)
72
+ if not data_queue:
73
+ depth += 1
74
+ time.sleep(0.1)
75
+ continue
76
+ data_item = data_queue.get(block=True, timeout=0.1)
77
+ if data_item.is_end:
78
+ # 请求结束
79
+ del queue_map[request_id]
80
+ if data_item.reply:
81
+ reply_map[request_id] += data_item.reply
82
+ usage = data_item.usage
83
+ break
84
+
85
+ reply_map[request_id] += data_item.reply
86
+ depth += 1
87
+ except Exception as e:
88
+ depth += 1
89
+ continue
90
+ t2 = time.time()
91
+ logger.info(
92
+ f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}"
93
+ )
94
+ self.sessions.session_reply(reply_map[request_id], session_id,
95
+ usage.get("total_tokens"))
96
+ reply = Reply(ReplyType.TEXT, reply_map[request_id])
97
+ del reply_map[request_id]
98
+ return reply
99
+ else:
100
+ reply = Reply(ReplyType.ERROR,
101
+ "Bot不支持处理{}类型的消息".format(context.type))
102
+ return reply
103
+
104
+ def create_web_socket(self, prompt, session_id, temperature=0.5):
105
+ logger.info(f"[XunFei] start connect, prompt={prompt}")
106
+ websocket.enableTrace(False)
107
+ wsUrl = self.create_url()
108
+ ws = websocket.WebSocketApp(wsUrl,
109
+ on_message=on_message,
110
+ on_error=on_error,
111
+ on_close=on_close,
112
+ on_open=on_open)
113
+ data_queue = queue.Queue(1000)
114
+ queue_map[session_id] = data_queue
115
+ ws.appid = self.app_id
116
+ ws.question = prompt
117
+ ws.domain = self.domain
118
+ ws.session_id = session_id
119
+ ws.temperature = temperature
120
+ ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
121
+
122
+ def gen_request_id(self, session_id: str):
123
+ return session_id + "_" + str(int(time.time())) + "" + str(
124
+ random.randint(0, 100))
125
+
126
+ # 生成url
127
+ def create_url(self):
128
+ # 生成RFC1123格式的时间戳
129
+ now = datetime.now()
130
+ date = format_date_time(mktime(now.timetuple()))
131
+
132
+ # 拼接字符串
133
+ signature_origin = "host: " + self.host + "\n"
134
+ signature_origin += "date: " + date + "\n"
135
+ signature_origin += "GET " + self.path + " HTTP/1.1"
136
+
137
+ # 进行hmac-sha256进行加密
138
+ signature_sha = hmac.new(self.api_secret.encode('utf-8'),
139
+ signature_origin.encode('utf-8'),
140
+ digestmod=hashlib.sha256).digest()
141
+
142
+ signature_sha_base64 = base64.b64encode(signature_sha).decode(
143
+ encoding='utf-8')
144
+
145
+ authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \
146
+ f'signature="{signature_sha_base64}"'
147
+
148
+ authorization = base64.b64encode(
149
+ authorization_origin.encode('utf-8')).decode(encoding='utf-8')
150
+
151
+ # 将请求的鉴权参数组合为字典
152
+ v = {"authorization": authorization, "date": date, "host": self.host}
153
+ # 拼接鉴权参数,生成url
154
+ url = self.spark_url + '?' + urlencode(v)
155
+ # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
156
+ return url
157
+
158
+ def gen_params(self, appid, domain, question):
159
+ """
160
+ 通过appid和用户的提问来生成请求参数
161
+ """
162
+ data = {
163
+ "header": {
164
+ "app_id": appid,
165
+ "uid": "1234"
166
+ },
167
+ "parameter": {
168
+ "chat": {
169
+ "domain": domain,
170
+ "random_threshold": 0.5,
171
+ "max_tokens": 2048,
172
+ "auditing": "default"
173
+ }
174
+ },
175
+ "payload": {
176
+ "message": {
177
+ "text": question
178
+ }
179
+ }
180
+ }
181
+ return data
182
+
183
+
184
+ class ReplyItem:
185
+ def __init__(self, reply, usage=None, is_end=False):
186
+ self.is_end = is_end
187
+ self.reply = reply
188
+ self.usage = usage
189
+
190
+
191
+ # 收到websocket错误的处理
192
+ def on_error(ws, error):
193
+ logger.error(f"[XunFei] error: {str(error)}")
194
+
195
+
196
+ # 收到websocket关闭的处理
197
+ def on_close(ws, one, two):
198
+ data_queue = queue_map.get(ws.session_id)
199
+ data_queue.put("END")
200
+
201
+
202
+ # 收到websocket连接建立的处理
203
+ def on_open(ws):
204
+ logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
205
+ thread.start_new_thread(run, (ws, ))
206
+
207
+
208
+ def run(ws, *args):
209
+ data = json.dumps(
210
+ gen_params(appid=ws.appid,
211
+ domain=ws.domain,
212
+ question=ws.question,
213
+ temperature=ws.temperature))
214
+ ws.send(data)
215
+
216
+
217
+ # Websocket 操作
218
+ # 收到websocket消息的处理
219
+ def on_message(ws, message):
220
+ data = json.loads(message)
221
+ code = data['header']['code']
222
+ if code != 0:
223
+ logger.error(f'请求错误: {code}, {data}')
224
+ ws.close()
225
+ else:
226
+ choices = data["payload"]["choices"]
227
+ status = choices["status"]
228
+ content = choices["text"][0]["content"]
229
+ data_queue = queue_map.get(ws.session_id)
230
+ if not data_queue:
231
+ logger.error(
232
+ f"[XunFei] can't find data queue, session_id={ws.session_id}")
233
+ return
234
+ reply_item = ReplyItem(content)
235
+ if status == 2:
236
+ usage = data["payload"].get("usage")
237
+ reply_item = ReplyItem(content, usage)
238
+ reply_item.is_end = True
239
+ ws.close()
240
+ data_queue.put(reply_item)
241
+
242
+
243
+ def gen_params(appid, domain, question, temperature=0.5):
244
+ """
245
+ 通过appid和用户的提问来生成请求参数
246
+ """
247
+ data = {
248
+ "header": {
249
+ "app_id": appid,
250
+ "uid": "1234"
251
+ },
252
+ "parameter": {
253
+ "chat": {
254
+ "domain": domain,
255
+ "temperature": temperature,
256
+ "random_threshold": 0.5,
257
+ "max_tokens": 2048,
258
+ "auditing": "default"
259
+ }
260
+ },
261
+ "payload": {
262
+ "message": {
263
+ "text": question
264
+ }
265
+ }
266
+ }
267
+ return data
bridge/bridge.py ADDED
@@ -0,0 +1,80 @@
 
 
1
+ from bot.bot_factory import create_bot
2
+ from bridge.context import Context
3
+ from bridge.reply import Reply
4
+ from common import const
5
+ from common.log import logger
6
+ from common.singleton import singleton
7
+ from config import conf
8
+ from translate.factory import create_translator
9
+ from voice.factory import create_voice
10
+
11
+
12
+ @singleton
13
+ class Bridge(object):
14
+ def __init__(self):
15
+ self.btype = {
16
+ "chat": const.CHATGPT,
17
+ "voice_to_text": conf().get("voice_to_text", "openai"),
18
+ "text_to_voice": conf().get("text_to_voice", "google"),
19
+ "translate": conf().get("translate", "baidu"),
20
+ }
21
+ model_type = conf().get("model") or const.GPT35
22
+ if model_type in ["text-davinci-003"]:
23
+ self.btype["chat"] = const.OPEN_AI
24
+ if conf().get("use_azure_chatgpt", False):
25
+ self.btype["chat"] = const.CHATGPTONAZURE
26
+ if model_type in ["wenxin", "wenxin-4"]:
27
+ self.btype["chat"] = const.BAIDU
28
+ if model_type in ["xunfei"]:
29
+ self.btype["chat"] = const.XUNFEI
30
+ if model_type in [const.QWEN]:
31
+ self.btype["chat"] = const.QWEN
32
+ if conf().get("use_linkai") and conf().get("linkai_api_key"):
33
+ self.btype["chat"] = const.LINKAI
34
+ if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
35
+ self.btype["voice_to_text"] = const.LINKAI
36
+ if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
37
+ self.btype["text_to_voice"] = const.LINKAI
38
+ if model_type in ["claude"]:
39
+ self.btype["chat"] = const.CLAUDEAI
40
+ self.bots = {}
41
+ self.chat_bots = {}
42
+
43
+ def get_bot(self, typename):
44
+ if self.bots.get(typename) is None:
45
+ logger.info("create bot {} for {}".format(self.btype[typename], typename))
46
+ if typename == "text_to_voice":
47
+ self.bots[typename] = create_voice(self.btype[typename])
48
+ elif typename == "voice_to_text":
49
+ self.bots[typename] = create_voice(self.btype[typename])
50
+ elif typename == "chat":
51
+ self.bots[typename] = create_bot(self.btype[typename])
52
+ elif typename == "translate":
53
+ self.bots[typename] = create_translator(self.btype[typename])
54
+ return self.bots[typename]
55
+
56
+ def get_bot_type(self, typename):
57
+ return self.btype[typename]
58
+
59
+ def fetch_reply_content(self, query, context: Context) -> Reply:
60
+ return self.get_bot("chat").reply(query, context)
61
+
62
+ def fetch_voice_to_text(self, voiceFile) -> Reply:
63
+ return self.get_bot("voice_to_text").voiceToText(voiceFile)
64
+
65
+ def fetch_text_to_voice(self, text) -> Reply:
66
+ return self.get_bot("text_to_voice").textToVoice(text)
67
+
68
+ def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply:
69
+ return self.get_bot("translate").translate(text, from_lang, to_lang)
70
+
71
+ def find_chat_bot(self, bot_type: str):
72
+ if self.chat_bots.get(bot_type) is None:
73
+ self.chat_bots[bot_type] = create_bot(bot_type)
74
+ return self.chat_bots.get(bot_type)
75
+
76
+ def reset_bot(self):
77
+ """
78
+ 重置bot路由
79
+ """
80
+ self.__init__()
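For context, the channels added later in this commit reach these bots only through the Bridge singleton, which lazily builds the bot selected in __init__ above. A minimal sketch of that call path, with the query text and session id purely illustrative:

    from bridge.bridge import Bridge
    from bridge.context import Context, ContextType

    context = Context(ContextType.TEXT, "hello")
    context["session_id"] = "demo_user"
    reply = Bridge().fetch_reply_content("hello", context)  # routed to the bot behind btype["chat"]
    print(reply.type, reply.content)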
bridge/context.py ADDED
@@ -0,0 +1,71 @@
 
 
1
+ # encoding:utf-8
2
+
3
+ from enum import Enum
4
+
5
+
6
+ class ContextType(Enum):
7
+ TEXT = 1 # 文本消息
8
+ VOICE = 2 # 音频消息
9
+ IMAGE = 3 # 图片消息
10
+ FILE = 4 # 文件信息
11
+ VIDEO = 5 # 视频信息
12
+ SHARING = 6 # 分享信息
13
+
14
+ IMAGE_CREATE = 10 # 创建图片命令
15
+ ACCEPT_FRIEND = 19 # 同意好友请求
16
+ JOIN_GROUP = 20 # 加入群聊
17
+ PATPAT = 21 # 拍了拍
18
+ FUNCTION = 22 # 函数调用
19
+ EXIT_GROUP = 23 #退出
20
+
21
+
22
+ def __str__(self):
23
+ return self.name
24
+
25
+
26
+ class Context:
27
+ def __init__(self, type: ContextType = None, content=None, kwargs=None):
28
+ self.type = type
29
+ self.content = content
30
+ self.kwargs = kwargs if kwargs is not None else {}  # avoid sharing a mutable default dict across Context instances
31
+
32
+ def __contains__(self, key):
33
+ if key == "type":
34
+ return self.type is not None
35
+ elif key == "content":
36
+ return self.content is not None
37
+ else:
38
+ return key in self.kwargs
39
+
40
+ def __getitem__(self, key):
41
+ if key == "type":
42
+ return self.type
43
+ elif key == "content":
44
+ return self.content
45
+ else:
46
+ return self.kwargs[key]
47
+
48
+ def get(self, key, default=None):
49
+ try:
50
+ return self[key]
51
+ except KeyError:
52
+ return default
53
+
54
+ def __setitem__(self, key, value):
55
+ if key == "type":
56
+ self.type = value
57
+ elif key == "content":
58
+ self.content = value
59
+ else:
60
+ self.kwargs[key] = value
61
+
62
+ def __delitem__(self, key):
63
+ if key == "type":
64
+ self.type = None
65
+ elif key == "content":
66
+ self.content = None
67
+ else:
68
+ del self.kwargs[key]
69
+
70
+ def __str__(self):
71
+ return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
bridge/reply.py ADDED
@@ -0,0 +1,31 @@
 
 
1
+ # encoding:utf-8
2
+
3
+ from enum import Enum
4
+
5
+
6
+ class ReplyType(Enum):
7
+ TEXT = 1 # 文本
8
+ VOICE = 2 # 音频文件
9
+ IMAGE = 3 # 图片文件
10
+ IMAGE_URL = 4 # 图片URL
11
+ VIDEO_URL = 5 # 视频URL
12
+ FILE = 6 # 文件
13
+ CARD = 7 # 微信名片,仅支持ntchat
14
+ InviteRoom = 8 # 邀请好友进群
15
+ INFO = 9
16
+ ERROR = 10
17
+ TEXT_ = 11 # 强制文本
18
+ VIDEO = 12
19
+ MINIAPP = 13 # 小程序
20
+
21
+ def __str__(self):
22
+ return self.name
23
+
24
+
25
+ class Reply:
26
+ def __init__(self, type: ReplyType = None, content=None):
27
+ self.type = type
28
+ self.content = content
29
+
30
+ def __str__(self):
31
+ return "Reply(type={}, content={})".format(self.type, self.content)
channel/channel.py ADDED
@@ -0,0 +1,43 @@
 
 
1
+ """
2
+ Message sending channel abstract class
3
+ """
4
+
5
+ from bridge.bridge import Bridge
6
+ from bridge.context import Context
7
+ from bridge.reply import *
8
+
9
+
10
+ class Channel(object):
11
+ NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
12
+
13
+ def startup(self):
14
+ """
15
+ init channel
16
+ """
17
+ raise NotImplementedError
18
+
19
+ def handle_text(self, msg):
20
+ """
21
+ process received msg
22
+ :param msg: message object
23
+ """
24
+ raise NotImplementedError
25
+
26
+ # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
27
+ def send(self, reply: Reply, context: Context):
28
+ """
29
+ send message to user
30
+ :param reply: reply content to send
31
+ :param context: context of the original message
32
+ :return:
33
+ """
34
+ raise NotImplementedError
35
+
36
+ def build_reply_content(self, query, context: Context = None) -> Reply:
37
+ return Bridge().fetch_reply_content(query, context)
38
+
39
+ def build_voice_to_text(self, voice_file) -> Reply:
40
+ return Bridge().fetch_voice_to_text(voice_file)
41
+
42
+ def build_text_to_voice(self, text) -> Reply:
43
+ return Bridge().fetch_text_to_voice(text)
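Concrete channels in this commit (terminal, wechatmp, wework, feishu, etc.) subclass Channel and implement startup and send, inheriting the reply-building helpers above. A minimal sketch of such a subclass; the console-based transport is illustrative only:

    from bridge.context import Context, ContextType
    from bridge.reply import Reply
    from channel.channel import Channel

    class EchoChannel(Channel):
        def startup(self):
            query = input("you> ")
            context = Context(ContextType.TEXT, query)
            context["session_id"] = "echo_user"
            reply = self.build_reply_content(query, context)
            self.send(reply, context)

        def send(self, reply: Reply, context: Context):
            print("bot>", reply.content)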
channel/channel_factory.py ADDED
@@ -0,0 +1,44 @@
 
 
1
+ """
2
+ channel factory
3
+ """
4
+ from common import const
5
+
6
+ def create_channel(channel_type):
7
+ """
8
+ create a channel instance
9
+ :param channel_type: channel type code
10
+ :return: channel instance
11
+ """
12
+ if channel_type == "wx":
13
+ from channel.wechat.wechat_channel import WechatChannel
14
+
15
+ return WechatChannel()
16
+ elif channel_type == "wxy":
17
+ from channel.wechat.wechaty_channel import WechatyChannel
18
+
19
+ return WechatyChannel()
20
+ elif channel_type == "terminal":
21
+ from channel.terminal.terminal_channel import TerminalChannel
22
+
23
+ return TerminalChannel()
24
+ elif channel_type == "wechatmp":
25
+ from channel.wechatmp.wechatmp_channel import WechatMPChannel
26
+
27
+ return WechatMPChannel(passive_reply=True)
28
+ elif channel_type == "wechatmp_service":
29
+ from channel.wechatmp.wechatmp_channel import WechatMPChannel
30
+
31
+ return WechatMPChannel(passive_reply=False)
32
+ elif channel_type == "wechatcom_app":
33
+ from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
34
+
35
+ return WechatComAppChannel()
36
+ elif channel_type == "wework":
37
+ from channel.wework.wework_channel import WeworkChannel
38
+ return WeworkChannel()
39
+
40
+ elif channel_type == const.FEISHU:
41
+ from channel.feishu.feishu_channel import FeiShuChanel
42
+ return FeiShuChanel()
43
+
44
+ raise RuntimeError("unknown channel_type: " + str(channel_type))
channel/chat_channel.py ADDED
@@ -0,0 +1,392 @@
 
 
1
+ import os
2
+ import re
3
+ import threading
4
+ import time
5
+ from asyncio import CancelledError
6
+ from concurrent.futures import Future, ThreadPoolExecutor
7
+
8
+ from bridge.context import *
9
+ from bridge.reply import *
10
+ from channel.channel import Channel
11
+ from common.dequeue import Dequeue
12
+ from common import memory
13
+ from plugins import *
14
+
15
+ try:
16
+ from voice.audio_convert import any_to_wav
17
+ except Exception as e:
18
+ pass
19
+
20
+
21
+ # 抽象类, 它包含了与消息通道无关的通用处理逻辑
22
+ class ChatChannel(Channel):
23
+ name = None # 登录的用户名
24
+ user_id = None # 登录的用户id
25
+ futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
26
+ sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
27
+ lock = threading.Lock() # 用于控制对sessions的访问
28
+ handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
29
+
30
+ def __init__(self):
31
+ _thread = threading.Thread(target=self.consume)
32
+ _thread.setDaemon(True)
33
+ _thread.start()
34
+
35
+ # 根据消息构造context,消息内容相关的触发项写在这里
36
+ def _compose_context(self, ctype: ContextType, content, **kwargs):
37
+ context = Context(ctype, content)
38
+ context.kwargs = kwargs
39
+ # context首次传入时,origin_ctype是None,
40
+ # 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
41
+ # origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
42
+ if "origin_ctype" not in context:
43
+ context["origin_ctype"] = ctype
44
+ # context首次传入时,receiver是None,根据类型设置receiver
45
+ first_in = "receiver" not in context
46
+ # 群名匹配过程,设置session_id和receiver
47
+ if first_in: # context首次传入时,receiver是None,根据类型设置receiver
48
+ config = conf()
49
+ cmsg = context["msg"]
50
+ user_data = conf().get_user_data(cmsg.from_user_id)
51
+ context["openai_api_key"] = user_data.get("openai_api_key")
52
+ context["gpt_model"] = user_data.get("gpt_model")
53
+ if context.get("isgroup", False):
54
+ group_name = cmsg.other_user_nickname
55
+ group_id = cmsg.other_user_id
56
+
57
+ group_name_white_list = config.get("group_name_white_list", [])
58
+ group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
59
+ if any(
60
+ [
61
+ group_name in group_name_white_list,
62
+ "ALL_GROUP" in group_name_white_list,
63
+ check_contain(group_name, group_name_keyword_white_list),
64
+ ]
65
+ ):
66
+ group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
67
+ session_id = cmsg.actual_user_id
68
+ if any(
69
+ [
70
+ group_name in group_chat_in_one_session,
71
+ "ALL_GROUP" in group_chat_in_one_session,
72
+ ]
73
+ ):
74
+ session_id = group_id
75
+ else:
76
+ return None
77
+ context["session_id"] = session_id
78
+ context["receiver"] = group_id
79
+ else:
80
+ context["session_id"] = cmsg.other_user_id
81
+ context["receiver"] = cmsg.other_user_id
82
+ e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
83
+ context = e_context["context"]
84
+ if e_context.is_pass() or context is None:
85
+ return context
86
+ if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
87
+ logger.debug("[WX]self message skipped")
88
+ return None
89
+
90
+ # 消息内容匹配过程,并处理content
91
+ if ctype == ContextType.TEXT:
92
+ if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
93
+ logger.debug(content)
94
+ logger.debug("[WX]reference query skipped")
95
+ return None
96
+
97
+ nick_name_black_list = conf().get("nick_name_black_list", [])
98
+ if context.get("isgroup", False): # 群聊
99
+ # 校验关键字
100
+ match_prefix = check_prefix(content, conf().get("group_chat_prefix"))
101
+ match_contain = check_contain(content, conf().get("group_chat_keyword"))
102
+ flag = False
103
+ if context["msg"].to_user_id != context["msg"].actual_user_id:
104
+ if match_prefix is not None or match_contain is not None:
105
+ flag = True
106
+ if match_prefix:
107
+ content = content.replace(match_prefix, "", 1).strip()
108
+ if context["msg"].is_at:
109
+ nick_name = context["msg"].actual_user_nickname
110
+ if nick_name and nick_name in nick_name_black_list:
111
+ # 黑名单过滤
112
+ logger.warning(f"[WX] Nickname {nick_name} in In BlackList, ignore")
113
+ return None
114
+
115
+ logger.info("[WX]receive group at")
116
+ if not conf().get("group_at_off", False):
117
+ flag = True
118
+ pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
119
+ subtract_res = re.sub(pattern, r"", content)
120
+ if isinstance(context["msg"].at_list, list):
121
+ for at in context["msg"].at_list:
122
+ pattern = f"@{re.escape(at)}(\u2005|\u0020)"
123
+ subtract_res = re.sub(pattern, r"", subtract_res)
124
+ if subtract_res == content and context["msg"].self_display_name:
125
+ # 前缀移除后没有变化,使用群昵称再次移除
126
+ pattern = f"@{re.escape(context['msg'].self_display_name)}(\u2005|\u0020)"
127
+ subtract_res = re.sub(pattern, r"", content)
128
+ content = subtract_res
129
+ if not flag:
130
+ if context["origin_ctype"] == ContextType.VOICE:
131
+ logger.info("[WX]receive group voice, but checkprefix didn't match")
132
+ return None
133
+ else: # 单聊
134
+ nick_name = context["msg"].from_user_nickname
135
+ if nick_name and nick_name in nick_name_black_list:
136
+ # 黑名单过滤
137
+ logger.warning(f"[WX] Nickname '{nick_name}' in In BlackList, ignore")
138
+ return None
139
+
140
+ match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
141
+ if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
142
+ content = content.replace(match_prefix, "", 1).strip()
143
+ elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
144
+ pass
145
+ else:
146
+ return None
147
+ content = content.strip()
148
+ img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
149
+ if img_match_prefix:
150
+ content = content.replace(img_match_prefix, "", 1)
151
+ context.type = ContextType.IMAGE_CREATE
152
+ else:
153
+ context.type = ContextType.TEXT
154
+ context.content = content.strip()
155
+ if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
156
+ context["desire_rtype"] = ReplyType.VOICE
157
+ elif context.type == ContextType.VOICE:
158
+ if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
159
+ context["desire_rtype"] = ReplyType.VOICE
160
+
161
+ return context
162
+
163
+ def _handle(self, context: Context):
164
+ if context is None or not context.content:
165
+ return
166
+ logger.debug("[WX] ready to handle context: {}".format(context))
167
+ # reply的构建步骤
168
+ reply = self._generate_reply(context)
169
+
170
+ logger.debug("[WX] ready to decorate reply: {}".format(reply))
171
+ # reply的包装步骤
172
+ reply = self._decorate_reply(context, reply)
173
+
174
+ # reply的发送步骤
175
+ self._send_reply(context, reply)
176
+
177
+ def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
178
+ e_context = PluginManager().emit_event(
179
+ EventContext(
180
+ Event.ON_HANDLE_CONTEXT,
181
+ {"channel": self, "context": context, "reply": reply},
182
+ )
183
+ )
184
+ reply = e_context["reply"]
185
+ if not e_context.is_pass():
186
+ logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
187
+ if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
188
+ context["channel"] = e_context["channel"]
189
+ reply = super().build_reply_content(context.content, context)
190
+ elif context.type == ContextType.VOICE: # 语音消息
191
+ cmsg = context["msg"]
192
+ cmsg.prepare()
193
+ file_path = context.content
194
+ wav_path = os.path.splitext(file_path)[0] + ".wav"
195
+ try:
196
+ any_to_wav(file_path, wav_path)
197
+ except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
198
+ logger.warning("[WX]any to wav error, use raw path. " + str(e))
199
+ wav_path = file_path
200
+ # 语音识别
201
+ reply = super().build_voice_to_text(wav_path)
202
+ # 删除临时文件
203
+ try:
204
+ os.remove(file_path)
205
+ if wav_path != file_path:
206
+ os.remove(wav_path)
207
+ except Exception as e:
208
+ pass
209
+ # logger.warning("[WX]delete temp file error: " + str(e))
210
+
211
+ if reply.type == ReplyType.TEXT:
212
+ new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
213
+ if new_context:
214
+ reply = self._generate_reply(new_context)
215
+ else:
216
+ return
217
+ elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑
218
+ memory.USER_IMAGE_CACHE[context["session_id"]] = {
219
+ "path": context.content,
220
+ "msg": context.get("msg")
221
+ }
222
+ elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑
223
+ pass
224
+ elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
225
+ pass
226
+ else:
227
+ logger.warning("[WX] unknown context type: {}".format(context.type))
228
+ return
229
+ return reply
230
+
231
+ def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
232
+ if reply and reply.type:
233
+ e_context = PluginManager().emit_event(
234
+ EventContext(
235
+ Event.ON_DECORATE_REPLY,
236
+ {"channel": self, "context": context, "reply": reply},
237
+ )
238
+ )
239
+ reply = e_context["reply"]
240
+ desire_rtype = context.get("desire_rtype")
241
+ if not e_context.is_pass() and reply and reply.type:
242
+ if reply.type in self.NOT_SUPPORT_REPLYTYPE:
243
+ logger.error("[WX]reply type not support: " + str(reply.type))
244
+ reply.content = "不支持发送的消息类型: " + str(reply.type)
245
+ reply.type = ReplyType.ERROR
246
+
247
+ if reply.type == ReplyType.TEXT:
248
+ reply_text = reply.content
249
+ if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
250
+ reply = super().build_text_to_voice(reply.content)
251
+ return self._decorate_reply(context, reply)
252
+ if context.get("isgroup", False):
253
+ if not context.get("no_need_at", False):
254
+ reply_text = "@" + context["msg"].actual_user_nickname + "\n" + reply_text.strip()
255
+ reply_text = conf().get("group_chat_reply_prefix", "") + reply_text + conf().get("group_chat_reply_suffix", "")
256
+ else:
257
+ reply_text = conf().get("single_chat_reply_prefix", "") + reply_text + conf().get("single_chat_reply_suffix", "")
258
+ reply.content = reply_text
259
+ elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
260
+ reply.content = "[" + str(reply.type) + "]\n" + reply.content
261
+ elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE or reply.type == ReplyType.FILE or reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL:
262
+ pass
263
+ else:
264
+ logger.error("[WX] unknown reply type: {}".format(reply.type))
265
+ return
266
+ if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
267
+ logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
268
+ return reply
269
+
270
+ def _send_reply(self, context: Context, reply: Reply):
271
+ if reply and reply.type:
272
+ e_context = PluginManager().emit_event(
273
+ EventContext(
274
+ Event.ON_SEND_REPLY,
275
+ {"channel": self, "context": context, "reply": reply},
276
+ )
277
+ )
278
+ reply = e_context["reply"]
279
+ if not e_context.is_pass() and reply and reply.type:
280
+ logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
281
+ self._send(reply, context)
282
+
283
+ def _send(self, reply: Reply, context: Context, retry_cnt=0):
284
+ try:
285
+ self.send(reply, context)
286
+ except Exception as e:
287
+ logger.error("[WX] sendMsg error: {}".format(str(e)))
288
+ if isinstance(e, NotImplementedError):
289
+ return
290
+ logger.exception(e)
291
+ if retry_cnt < 2:
292
+ time.sleep(3 + 3 * retry_cnt)
293
+ self._send(reply, context, retry_cnt + 1)
294
+
295
+ def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数
296
+ logger.debug("Worker return success, session_id = {}".format(session_id))
297
+
298
+ def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
299
+ logger.exception("Worker return exception: {}".format(exception))
300
+
301
+ def _thread_pool_callback(self, session_id, **kwargs):
302
+ def func(worker: Future):
303
+ try:
304
+ worker_exception = worker.exception()
305
+ if worker_exception:
306
+ self._fail_callback(session_id, exception=worker_exception, **kwargs)
307
+ else:
308
+ self._success_callback(session_id, **kwargs)
309
+ except CancelledError as e:
310
+ logger.info("Worker cancelled, session_id = {}".format(session_id))
311
+ except Exception as e:
312
+ logger.exception("Worker raise exception: {}".format(e))
313
+ with self.lock:
314
+ self.sessions[session_id][1].release()
315
+
316
+ return func
317
+
318
+ def produce(self, context: Context):
319
+ session_id = context["session_id"]
320
+ with self.lock:
321
+ if session_id not in self.sessions:
322
+ self.sessions[session_id] = [
323
+ Dequeue(),
324
+ threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
325
+ ]
326
+ if context.type == ContextType.TEXT and context.content.startswith("#"):
327
+ self.sessions[session_id][0].putleft(context) # 优先处理管理命令
328
+ else:
329
+ self.sessions[session_id][0].put(context)
330
+
331
+ # 消费者函数,单独线程,用于从消息队列中取出消息并处理
332
+ def consume(self):
333
+ while True:
334
+ with self.lock:
335
+ session_ids = list(self.sessions.keys())
336
+ for session_id in session_ids:
337
+ context_queue, semaphore = self.sessions[session_id]
338
+ if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
339
+ if not context_queue.empty():
340
+ context = context_queue.get()
341
+ logger.debug("[WX] consume context: {}".format(context))
342
+ future: Future = self.handler_pool.submit(self._handle, context)
343
+ future.add_done_callback(self._thread_pool_callback(session_id, context=context))
344
+ if session_id not in self.futures:
345
+ self.futures[session_id] = []
346
+ self.futures[session_id].append(future)
347
+ elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
348
+ self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
349
+ assert len(self.futures[session_id]) == 0, "thread pool error"
350
+ del self.sessions[session_id]
351
+ else:
352
+ semaphore.release()
353
+ time.sleep(0.1)
354
+
355
+ # 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
356
+ def cancel_session(self, session_id):
357
+ with self.lock:
358
+ if session_id in self.sessions:
359
+ for future in self.futures[session_id]:
360
+ future.cancel()
361
+ cnt = self.sessions[session_id][0].qsize()
362
+ if cnt > 0:
363
+ logger.info("Cancel {} messages in session {}".format(cnt, session_id))
364
+ self.sessions[session_id][0] = Dequeue()
365
+
366
+ def cancel_all_session(self):
367
+ with self.lock:
368
+ for session_id in self.sessions:
369
+ for future in self.futures[session_id]:
370
+ future.cancel()
371
+ cnt = self.sessions[session_id][0].qsize()
372
+ if cnt > 0:
373
+ logger.info("Cancel {} messages in session {}".format(cnt, session_id))
374
+ self.sessions[session_id][0] = Dequeue()
375
+
376
+
377
+ def check_prefix(content, prefix_list):
378
+ if not prefix_list:
379
+ return None
380
+ for prefix in prefix_list:
381
+ if content.startswith(prefix):
382
+ return prefix
383
+ return None
384
+
385
+
386
+ def check_contain(content, keyword_list):
387
+ if not keyword_list:
388
+ return None
389
+ for ky in keyword_list:
390
+ if content.find(ky) != -1:
391
+ return True
392
+ return None
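
The two helpers above drive the trigger logic in `_compose_context`. As a minimal, self-contained sketch of how `check_prefix` is typically used (the prefix list and message text below are hypothetical, not taken from any real config):

```python
# Illustrative only: mirrors how ChatChannel._compose_context strips a
# matched single-chat prefix before handing the text to the bot.

def check_prefix(content, prefix_list):
    if not prefix_list:
        return None
    for prefix in prefix_list:
        if content.startswith(prefix):
            return prefix
    return None


single_chat_prefix = ["bot", "@bot"]  # hypothetical value of conf().get("single_chat_prefix")
content = "bot 今天天气怎么样"

match = check_prefix(content, single_chat_prefix)
if match is not None:
    # drop the matched prefix once, then strip surrounding whitespace
    content = content.replace(match, "", 1).strip()

print(content)  # -> 今天天气怎么样
```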
channel/chat_message.py ADDED
@@ -0,0 +1,87 @@
1
+ """
2
+ 本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
3
+
4
+ 填好必填项(非群聊6个,群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel(另见下方的最小示例)
5
+
6
+ ChatMessage
7
+ msg_id: 消息id (必填)
8
+ create_time: 消息创建时间
9
+
10
+ ctype: 消息类型 : ContextType (必填)
11
+ content: 消息内容, 如果是声音/图片,这里是文件路径 (必填)
12
+
13
+ from_user_id: 发送者id (必填)
14
+ from_user_nickname: 发送者昵称
15
+ to_user_id: 接收者id (必填)
16
+ to_user_nickname: 接收者昵称
17
+
18
+ other_user_id: 对方的id,如果你是发送者,那这个就是接收者id,如果你是接收者,那这个就是发送者id,如果是群消息,那这一直是群id (必填)
19
+ other_user_nickname: 同上
20
+
21
+ is_group: 是否是群消息 (群聊必填)
22
+ is_at: 是否被at
23
+
24
+ - (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
25
+ actual_user_id: 实际发送者id (群聊必填)
26
+ actual_user_nickname:实际发送者昵称
27
+ self_display_name: 自身的展示名,设置群昵称时,该字段表示群昵称
28
+
29
+ _prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等,
30
+ _prepared: 是否已经调用过准备函数
31
+ _rawmsg: 原始消息对象
32
+
33
+ """
34
+
35
+
36
+ class ChatMessage(object):
37
+ msg_id = None
38
+ create_time = None
39
+
40
+ ctype = None
41
+ content = None
42
+
43
+ from_user_id = None
44
+ from_user_nickname = None
45
+ to_user_id = None
46
+ to_user_nickname = None
47
+ other_user_id = None
48
+ other_user_nickname = None
49
+ my_msg = False
50
+ self_display_name = None
51
+
52
+ is_group = False
53
+ is_at = False
54
+ actual_user_id = None
55
+ actual_user_nickname = None
56
+ at_list = None
57
+
58
+ _prepare_fn = None
59
+ _prepared = False
60
+ _rawmsg = None
61
+
62
+ def __init__(self, _rawmsg):
63
+ self._rawmsg = _rawmsg
64
+
65
+ def prepare(self):
66
+ if self._prepare_fn and not self._prepared:
67
+ self._prepared = True
68
+ self._prepare_fn()
69
+
70
+ def __str__(self):
71
+ return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}, at_list={}".format(
72
+ self.msg_id,
73
+ self.create_time,
74
+ self.ctype,
75
+ self.content,
76
+ self.from_user_id,
77
+ self.from_user_nickname,
78
+ self.to_user_id,
79
+ self.to_user_nickname,
80
+ self.other_user_id,
81
+ self.other_user_nickname,
82
+ self.is_group,
83
+ self.is_at,
84
+ self.actual_user_id,
85
+ self.actual_user_nickname,
86
+ self.at_list
87
+ )
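
The docstring at the top of this file lists the fields a new channel must fill before its messages can flow through ChatChannel. A hedged, minimal sketch of a private-chat message for a hypothetical channel (the `raw_msg` keys are placeholders; see TerminalMessage later in this upload for a real in-repo example):

```python
# Hypothetical example, not part of the repo: the always-required fields
# (msg_id, ctype, content, from_user_id, to_user_id, other_user_id) for a
# non-group message; group messages would additionally set is_group and
# actual_user_id.
from bridge.context import ContextType
from channel.chat_message import ChatMessage


class MyChannelMessage(ChatMessage):
    def __init__(self, raw_msg: dict):
        super().__init__(raw_msg)                  # keeps the original object in _rawmsg
        self.msg_id = raw_msg["id"]                # 必填
        self.ctype = ContextType.TEXT              # 必填
        self.content = raw_msg["text"]             # 必填(图片/语音时为文件路径)
        self.from_user_id = raw_msg["sender_id"]   # 必填
        self.to_user_id = raw_msg["receiver_id"]   # 必填
        self.other_user_id = self.from_user_id     # 必填:私聊时即对方的id
```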
channel/feishu/feishu_channel.py ADDED
@@ -0,0 +1,250 @@
1
+ """
2
+ 飞书通道接入
3
+
4
+ @author Saboteur7
5
+ @Date 2023/11/19
6
+ """
7
+
8
+ # -*- coding=utf-8 -*-
9
+ import uuid
10
+
11
+ import requests
12
+ import web
13
+ from channel.feishu.feishu_message import FeishuMessage
14
+ from bridge.context import Context
15
+ from bridge.reply import Reply, ReplyType
16
+ from common.log import logger
17
+ from common.singleton import singleton
18
+ from config import conf
19
+ from common.expired_dict import ExpiredDict
20
+ from bridge.context import ContextType
21
+ from channel.chat_channel import ChatChannel, check_prefix
22
+ from common import utils
23
+ import json
24
+ import os
25
+
26
+ URL_VERIFICATION = "url_verification"
27
+
28
+
29
+ @singleton
30
+ class FeiShuChanel(ChatChannel):
31
+ feishu_app_id = conf().get('feishu_app_id')
32
+ feishu_app_secret = conf().get('feishu_app_secret')
33
+ feishu_token = conf().get('feishu_token')
34
+
35
+ def __init__(self):
36
+ super().__init__()
37
+ # 历史消息id暂存,用于幂等控制
38
+ self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
39
+ logger.info("[FeiShu] app_id={}, app_secret={}, verification_token={}".format(
40
+ self.feishu_app_id, self.feishu_app_secret, self.feishu_token))
41
+ # 无需群校验和前缀
42
+ conf()["group_name_white_list"] = ["ALL_GROUP"]
43
+ conf()["single_chat_prefix"] = []
44
+
45
+ def startup(self):
46
+ urls = (
47
+ '/', 'channel.feishu.feishu_channel.FeishuController'
48
+ )
49
+ app = web.application(urls, globals(), autoreload=False)
50
+ port = conf().get("feishu_port", 9891)
51
+ web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
52
+
53
+ def send(self, reply: Reply, context: Context):
54
+ msg = context["msg"]
55
+ is_group = context["isgroup"]
56
+ headers = {
57
+ "Authorization": "Bearer " + msg.access_token,
58
+ "Content-Type": "application/json",
59
+ }
60
+ msg_type = "text"
61
+ logger.info(f"[FeiShu] start send reply message, type={context.type}, content={reply.content}")
62
+ reply_content = reply.content
63
+ content_key = "text"
64
+ if reply.type == ReplyType.IMAGE_URL:
65
+ # 图片上传
66
+ reply_content = self._upload_image_url(reply.content, msg.access_token)
67
+ if not reply_content:
68
+ logger.warning("[FeiShu] upload file failed")
69
+ return
70
+ msg_type = "image"
71
+ content_key = "image_key"
72
+ if is_group:
73
+ # 群聊中直接回复
74
+ url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.msg_id}/reply"
75
+ data = {
76
+ "msg_type": msg_type,
77
+ "content": json.dumps({content_key: reply_content})
78
+ }
79
+ res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10))
80
+ else:
81
+ url = "https://open.feishu.cn/open-apis/im/v1/messages"
82
+ params = {"receive_id_type": context.get("receive_id_type")}
83
+ data = {
84
+ "receive_id": context.get("receiver"),
85
+ "msg_type": msg_type,
86
+ "content": json.dumps({content_key: reply_content})
87
+ }
88
+ res = requests.post(url=url, headers=headers, params=params, json=data, timeout=(5, 10))
89
+ res = res.json()
90
+ if res.get("code") == 0:
91
+ logger.info(f"[FeiShu] send message success")
92
+ else:
93
+ logger.error(f"[FeiShu] send message failed, code={res.get('code')}, msg={res.get('msg')}")
94
+
95
+
96
+ def fetch_access_token(self) -> str:
97
+ url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal/"
98
+ headers = {
99
+ "Content-Type": "application/json"
100
+ }
101
+ req_body = {
102
+ "app_id": self.feishu_app_id,
103
+ "app_secret": self.feishu_app_secret
104
+ }
105
+ data = bytes(json.dumps(req_body), encoding='utf8')
106
+ response = requests.post(url=url, data=data, headers=headers)
107
+ if response.status_code == 200:
108
+ res = response.json()
109
+ if res.get("code") != 0:
110
+ logger.error(f"[FeiShu] get tenant_access_token error, code={res.get('code')}, msg={res.get('msg')}")
111
+ return ""
112
+ else:
113
+ return res.get("tenant_access_token")
114
+ else:
115
+ logger.error(f"[FeiShu] fetch token error, res={response}")
116
+
117
+
118
+ def _upload_image_url(self, img_url, access_token):
119
+ logger.debug(f"[WX] start download image, img_url={img_url}")
120
+ response = requests.get(img_url)
121
+ suffix = utils.get_path_suffix(img_url)
122
+ temp_name = str(uuid.uuid4()) + "." + suffix
123
+ if response.status_code == 200:
124
+ # 将图片内容保存为临时文件
125
+ with open(temp_name, "wb") as file:
126
+ file.write(response.content)
127
+
128
+ # upload
129
+ upload_url = "https://open.feishu.cn/open-apis/im/v1/images"
130
+ data = {
131
+ 'image_type': 'message'
132
+ }
133
+ headers = {
134
+ 'Authorization': f'Bearer {access_token}',
135
+ }
136
+ with open(temp_name, "rb") as file:
137
+ upload_response = requests.post(upload_url, files={"image": file}, data=data, headers=headers)
138
+ logger.info(f"[FeiShu] upload file, res={upload_response.content}")
139
+ os.remove(temp_name)
140
+ return upload_response.json().get("data").get("image_key")
141
+
142
+
143
+
144
+ class FeishuController:
145
+ # 类常量
146
+ FAILED_MSG = '{"success": false}'
147
+ SUCCESS_MSG = '{"success": true}'
148
+ MESSAGE_RECEIVE_TYPE = "im.message.receive_v1"
149
+
150
+ def GET(self):
151
+ return "Feishu service start success!"
152
+
153
+ def POST(self):
154
+ try:
155
+ channel = FeiShuChanel()
156
+
157
+ request = json.loads(web.data().decode("utf-8"))
158
+ logger.debug(f"[FeiShu] receive request: {request}")
159
+
160
+ # 1.事件订阅回调验证
161
+ if request.get("type") == URL_VERIFICATION:
162
+ verify_res = {"challenge": request.get("challenge")}
163
+ return json.dumps(verify_res)
164
+
165
+ # 2.消息接收处理
166
+ # token 校验
167
+ header = request.get("header")
168
+ if not header or header.get("token") != channel.feishu_token:
169
+ return self.FAILED_MSG
170
+
171
+ # 处理消息事件
172
+ event = request.get("event")
173
+ if header.get("event_type") == self.MESSAGE_RECEIVE_TYPE and event:
174
+ if not event.get("message") or not event.get("sender"):
175
+ logger.warning(f"[FeiShu] invalid message, msg={request}")
176
+ return self.FAILED_MSG
177
+ msg = event.get("message")
178
+
179
+ # 幂等判断
180
+ if channel.receivedMsgs.get(msg.get("message_id")):
181
+ logger.warning(f"[FeiShu] repeat msg filtered, event_id={header.get('event_id')}")
182
+ return self.SUCCESS_MSG
183
+ channel.receivedMsgs[msg.get("message_id")] = True
184
+
185
+ is_group = False
186
+ chat_type = msg.get("chat_type")
187
+ if chat_type == "group":
188
+ if not msg.get("mentions") and msg.get("message_type") == "text":
189
+ # 群聊中未@不响应
190
+ return self.SUCCESS_MSG
191
+ if msg.get("message_type") == "text" and msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name"):
192
+ # 不是@机器人,不响应
193
+ return self.SUCCESS_MSG
194
+ # 群聊
195
+ is_group = True
196
+ receive_id_type = "chat_id"
197
+ elif chat_type == "p2p":
198
+ receive_id_type = "open_id"
199
+ else:
200
+ logger.warning("[FeiShu] message ignore")
201
+ return self.SUCCESS_MSG
202
+ # 构造飞书消息对象
203
+ feishu_msg = FeishuMessage(event, is_group=is_group, access_token=channel.fetch_access_token())
204
+ if not feishu_msg:
205
+ return self.SUCCESS_MSG
206
+
207
+ context = self._compose_context(
208
+ feishu_msg.ctype,
209
+ feishu_msg.content,
210
+ isgroup=is_group,
211
+ msg=feishu_msg,
212
+ receive_id_type=receive_id_type,
213
+ no_need_at=True
214
+ )
215
+ if context:
216
+ channel.produce(context)
217
+ logger.info(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")
218
+ return self.SUCCESS_MSG
219
+
220
+ except Exception as e:
221
+ logger.error(e)
222
+ return self.FAILED_MSG
223
+
224
+ def _compose_context(self, ctype: ContextType, content, **kwargs):
225
+ context = Context(ctype, content)
226
+ context.kwargs = kwargs
227
+ if "origin_ctype" not in context:
228
+ context["origin_ctype"] = ctype
229
+
230
+ cmsg = context["msg"]
231
+ context["session_id"] = cmsg.from_user_id
232
+ context["receiver"] = cmsg.other_user_id
233
+
234
+ if ctype == ContextType.TEXT:
235
+ # 1.文本请求
236
+ # 图片生成处理
237
+ img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
238
+ if img_match_prefix:
239
+ content = content.replace(img_match_prefix, "", 1)
240
+ context.type = ContextType.IMAGE_CREATE
241
+ else:
242
+ context.type = ContextType.TEXT
243
+ context.content = content.strip()
244
+
245
+ elif context.type == ContextType.VOICE:
246
+ # 2.语音请求
247
+ if "desire_rtype" not in context and conf().get("voice_reply_voice"):
248
+ context["desire_rtype"] = ReplyType.VOICE
249
+
250
+ return context
channel/feishu/feishu_message.py ADDED
@@ -0,0 +1,92 @@
1
+ from bridge.context import ContextType
2
+ from channel.chat_message import ChatMessage
3
+ import json
4
+ import requests
5
+ from common.log import logger
6
+ from common.tmp_dir import TmpDir
7
+ from common import utils
8
+
9
+
10
+ class FeishuMessage(ChatMessage):
11
+ def __init__(self, event: dict, is_group=False, access_token=None):
12
+ super().__init__(event)
13
+ msg = event.get("message")
14
+ sender = event.get("sender")
15
+ self.access_token = access_token
16
+ self.msg_id = msg.get("message_id")
17
+ self.create_time = msg.get("create_time")
18
+ self.is_group = is_group
19
+ msg_type = msg.get("message_type")
20
+
21
+ if msg_type == "text":
22
+ self.ctype = ContextType.TEXT
23
+ content = json.loads(msg.get('content'))
24
+ self.content = content.get("text").strip()
25
+ elif msg_type == "file":
26
+ self.ctype = ContextType.FILE
27
+ content = json.loads(msg.get("content"))
28
+ file_key = content.get("file_key")
29
+ file_name = content.get("file_name")
30
+
31
+ self.content = TmpDir().path() + file_key + "." + utils.get_path_suffix(file_name)
32
+
33
+ def _download_file():
34
+ # 如果响应状态码是200,则将响应内容写入本地文件
35
+ url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{file_key}"
36
+ headers = {
37
+ "Authorization": "Bearer " + access_token,
38
+ }
39
+ params = {
40
+ "type": "file"
41
+ }
42
+ response = requests.get(url=url, headers=headers, params=params)
43
+ if response.status_code == 200:
44
+ with open(self.content, "wb") as f:
45
+ f.write(response.content)
46
+ else:
47
+ logger.info(f"[FeiShu] Failed to download file, key={file_key}, res={response.text}")
48
+ self._prepare_fn = _download_file
49
+
50
+ # elif msg.type == "voice":
51
+ # self.ctype = ContextType.VOICE
52
+ # self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
53
+ #
54
+ # def download_voice():
55
+ # # 如果响应状态码是200,则将响应内容写入本地文件
56
+ # response = client.media.download(msg.media_id)
57
+ # if response.status_code == 200:
58
+ # with open(self.content, "wb") as f:
59
+ # f.write(response.content)
60
+ # else:
61
+ # logger.info(f"[wechatcom] Failed to download voice file, {response.content}")
62
+ #
63
+ # self._prepare_fn = download_voice
64
+ # elif msg.type == "image":
65
+ # self.ctype = ContextType.IMAGE
66
+ # self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
67
+ #
68
+ # def download_image():
69
+ # # 如果响应状态码是200,则将响应内容写入本地文件
70
+ # response = client.media.download(msg.media_id)
71
+ # if response.status_code == 200:
72
+ # with open(self.content, "wb") as f:
73
+ # f.write(response.content)
74
+ # else:
75
+ # logger.info(f"[wechatcom] Failed to download image file, {response.content}")
76
+ #
77
+ # self._prepare_fn = download_image
78
+ else:
79
+ raise NotImplementedError("Unsupported message type: Type:{} ".format(msg_type))
80
+
81
+ self.from_user_id = sender.get("sender_id").get("open_id")
82
+ self.to_user_id = event.get("app_id")
83
+ if is_group:
84
+ # 群聊
85
+ self.other_user_id = msg.get("chat_id")
86
+ self.actual_user_id = self.from_user_id
87
+ self.content = self.content.replace("@_user_1", "").strip()
88
+ self.actual_user_nickname = ""
89
+ else:
90
+ # 私聊
91
+ self.other_user_id = self.from_user_id
92
+ self.actual_user_id = self.from_user_id
channel/terminal/terminal_channel.py ADDED
@@ -0,0 +1,92 @@
1
+ import sys
2
+
3
+ from bridge.context import *
4
+ from bridge.reply import Reply, ReplyType
5
+ from channel.chat_channel import ChatChannel, check_prefix
6
+ from channel.chat_message import ChatMessage
7
+ from common.log import logger
8
+ from config import conf
9
+
10
+
11
+ class TerminalMessage(ChatMessage):
12
+ def __init__(
13
+ self,
14
+ msg_id,
15
+ content,
16
+ ctype=ContextType.TEXT,
17
+ from_user_id="User",
18
+ to_user_id="Chatgpt",
19
+ other_user_id="Chatgpt",
20
+ ):
21
+ self.msg_id = msg_id
22
+ self.ctype = ctype
23
+ self.content = content
24
+ self.from_user_id = from_user_id
25
+ self.to_user_id = to_user_id
26
+ self.other_user_id = other_user_id
27
+
28
+
29
+ class TerminalChannel(ChatChannel):
30
+ NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
31
+
32
+ def send(self, reply: Reply, context: Context):
33
+ print("\nBot:")
34
+ if reply.type == ReplyType.IMAGE:
35
+ from PIL import Image
36
+
37
+ image_storage = reply.content
38
+ image_storage.seek(0)
39
+ img = Image.open(image_storage)
40
+ print("<IMAGE>")
41
+ img.show()
42
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
43
+ import io
44
+
45
+ import requests
46
+ from PIL import Image
47
+
48
+ img_url = reply.content
49
+ pic_res = requests.get(img_url, stream=True)
50
+ image_storage = io.BytesIO()
51
+ for block in pic_res.iter_content(1024):
52
+ image_storage.write(block)
53
+ image_storage.seek(0)
54
+ img = Image.open(image_storage)
55
+ print(img_url)
56
+ img.show()
57
+ else:
58
+ print(reply.content)
59
+ print("\nUser:", end="")
60
+ sys.stdout.flush()
61
+ return
62
+
63
+ def startup(self):
64
+ context = Context()
65
+ logger.setLevel("WARN")
66
+ print("\nPlease input your question:\nUser:", end="")
67
+ sys.stdout.flush()
68
+ msg_id = 0
69
+ while True:
70
+ try:
71
+ prompt = self.get_input()
72
+ except KeyboardInterrupt:
73
+ print("\nExiting...")
74
+ sys.exit()
75
+ msg_id += 1
76
+ trigger_prefixs = conf().get("single_chat_prefix", [""])
77
+ if check_prefix(prompt, trigger_prefixs) is None:
78
+ prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
79
+
80
+ context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
81
+ if context:
82
+ self.produce(context)
83
+ else:
84
+ raise Exception("context is None")
85
+
86
+ def get_input(self):
87
+ """
88
+ Multi-line input function
89
+ """
90
+ sys.stdout.flush()
91
+ line = input()
92
+ return line
channel/wechat/wechat_channel.py ADDED
@@ -0,0 +1,236 @@
1
+ # encoding:utf-8
2
+
3
+ """
4
+ wechat channel
5
+ """
6
+
7
+ import io
8
+ import json
9
+ import os
10
+ import threading
11
+ import time
12
+
13
+ import requests
14
+
15
+ from bridge.context import *
16
+ from bridge.reply import *
17
+ from channel.chat_channel import ChatChannel
18
+ from channel.wechat.wechat_message import *
19
+ from common.expired_dict import ExpiredDict
20
+ from common.log import logger
21
+ from common.singleton import singleton
22
+ from common.time_check import time_checker
23
+ from config import conf, get_appdata_dir
24
+ from lib import itchat
25
+ from lib.itchat.content import *
26
+
27
+
28
+ @itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING])
29
+ def handler_single_msg(msg):
30
+ try:
31
+ cmsg = WechatMessage(msg, False)
32
+ except NotImplementedError as e:
33
+ logger.debug("[WX]single message {} skipped: {}".format(msg["MsgId"], e))
34
+ return None
35
+ WechatChannel().handle_single(cmsg)
36
+ return None
37
+
38
+
39
+ @itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING], isGroupChat=True)
40
+ def handler_group_msg(msg):
41
+ try:
42
+ cmsg = WechatMessage(msg, True)
43
+ except NotImplementedError as e:
44
+ logger.debug("[WX]group message {} skipped: {}".format(msg["MsgId"], e))
45
+ return None
46
+ WechatChannel().handle_group(cmsg)
47
+ return None
48
+
49
+
50
+ def _check(func):
51
+ def wrapper(self, cmsg: ChatMessage):
52
+ msgId = cmsg.msg_id
53
+ if msgId in self.receivedMsgs:
54
+ logger.info("Wechat message {} already received, ignore".format(msgId))
55
+ return
56
+ self.receivedMsgs[msgId] = True
57
+ create_time = cmsg.create_time # 消息时间戳
58
+ if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
59
+ logger.debug("[WX]history message {} skipped".format(msgId))
60
+ return
61
+ if cmsg.my_msg and not cmsg.is_group:
62
+ logger.debug("[WX]my message {} skipped".format(msgId))
63
+ return
64
+ return func(self, cmsg)
65
+
66
+ return wrapper
67
+
68
+
69
+ # 可用的二维码生成接口
70
+ # https://api.qrserver.com/v1/create-qr-code/?size=400x400&data=https://www.abc.com
71
+ # https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
72
+ def qrCallback(uuid, status, qrcode):
73
+ # logger.debug("qrCallback: {} {}".format(uuid,status))
74
+ if status == "0":
75
+ try:
76
+ from PIL import Image
77
+
78
+ img = Image.open(io.BytesIO(qrcode))
79
+ _thread = threading.Thread(target=img.show, args=("QRCode",))
80
+ _thread.daemon = True
81
+ _thread.start()
82
+ except Exception as e:
83
+ pass
84
+
85
+ import qrcode
86
+
87
+ url = f"https://login.weixin.qq.com/l/{uuid}"
88
+
89
+ qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
90
+ qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400x400&data={}".format(url)
91
+ qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
92
+ qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
93
+ print("You can also scan QRCode in any website below:")
94
+ print(qr_api3)
95
+ print(qr_api4)
96
+ print(qr_api2)
97
+ print(qr_api1)
98
+
99
+ qr = qrcode.QRCode(border=1)
100
+ qr.add_data(url)
101
+ qr.make(fit=True)
102
+ qr.print_ascii(invert=True)
103
+
104
+
105
+ @singleton
106
+ class WechatChannel(ChatChannel):
107
+ NOT_SUPPORT_REPLYTYPE = []
108
+
109
+ def __init__(self):
110
+ super().__init__()
111
+ self.receivedMsgs = ExpiredDict(60 * 60)
112
+
113
+ def startup(self):
114
+ itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
115
+ # login by scan QRCode
116
+ hotReload = conf().get("hot_reload", False)
117
+ status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
118
+ itchat.auto_login(
119
+ enableCmdQR=2,
120
+ hotReload=hotReload,
121
+ statusStorageDir=status_path,
122
+ qrCallback=qrCallback,
123
+ )
124
+ self.user_id = itchat.instance.storageClass.userName
125
+ self.name = itchat.instance.storageClass.nickName
126
+ logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
127
+ # start message listener
128
+ itchat.run()
129
+
130
+ # handle_* 系列函数处理收到的消息后构造Context,然后传入produce函数中处理Context和发送回复
131
+ # Context包含了消息的所有信息,包括以下属性
132
+ # type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE
133
+ # content 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
134
+ # kwargs 附加参数字典,包含以下的key:
135
+ # session_id: 会话id
136
+ # isgroup: 是否是群聊
137
+ # receiver: 需要回复的对象
138
+ # msg: ChatMessage消息对象
139
+ # origin_ctype: 原始消息类型,语音转文字后,私聊时如果匹配前缀失败,会根据初始消息是否是语音来放宽触发规则
140
+ # desire_rtype: 希望回复类型,默认是文本回复,设置为ReplyType.VOICE是语音回复
141
+
142
+ @time_checker
143
+ @_check
144
+ def handle_single(self, cmsg: ChatMessage):
145
+ # filter system message
146
+ if cmsg.other_user_id in ["weixin"]:
147
+ return
148
+ if cmsg.ctype == ContextType.VOICE:
149
+ if conf().get("speech_recognition") != True:
150
+ return
151
+ logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
152
+ elif cmsg.ctype == ContextType.IMAGE:
153
+ logger.debug("[WX]receive image msg: {}".format(cmsg.content))
154
+ elif cmsg.ctype == ContextType.PATPAT:
155
+ logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
156
+ elif cmsg.ctype == ContextType.TEXT:
157
+ logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
158
+ else:
159
+ logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
160
+ context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
161
+ if context:
162
+ self.produce(context)
163
+
164
+ @time_checker
165
+ @_check
166
+ def handle_group(self, cmsg: ChatMessage):
167
+ if cmsg.ctype == ContextType.VOICE:
168
+ if conf().get("group_speech_recognition") != True:
169
+ return
170
+ logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
171
+ elif cmsg.ctype == ContextType.IMAGE:
172
+ logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
173
+ elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.ACCEPT_FRIEND, ContextType.EXIT_GROUP]:
174
+ logger.debug("[WX]receive note msg: {}".format(cmsg.content))
175
+ elif cmsg.ctype == ContextType.TEXT:
176
+ # logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
177
+ pass
178
+ elif cmsg.ctype == ContextType.FILE:
179
+ logger.debug(f"[WX]receive attachment msg, file_name={cmsg.content}")
180
+ else:
181
+ logger.debug("[WX]receive group msg: {}".format(cmsg.content))
182
+ context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
183
+ if context:
184
+ self.produce(context)
185
+
186
+ # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
187
+ def send(self, reply: Reply, context: Context):
188
+ receiver = context["receiver"]
189
+ if reply.type == ReplyType.TEXT:
190
+ itchat.send(reply.content, toUserName=receiver)
191
+ logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
192
+ elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
193
+ itchat.send(reply.content, toUserName=receiver)
194
+ logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
195
+ elif reply.type == ReplyType.VOICE:
196
+ itchat.send_file(reply.content, toUserName=receiver)
197
+ logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
198
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
199
+ img_url = reply.content
200
+ logger.debug(f"[WX] start download image, img_url={img_url}")
201
+ pic_res = requests.get(img_url, stream=True)
202
+ image_storage = io.BytesIO()
203
+ size = 0
204
+ for block in pic_res.iter_content(1024):
205
+ size += len(block)
206
+ image_storage.write(block)
207
+ logger.info(f"[WX] download image success, size={size}, img_url={img_url}")
208
+ image_storage.seek(0)
209
+ itchat.send_image(image_storage, toUserName=receiver)
210
+ logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
211
+ elif reply.type == ReplyType.IMAGE: # 从文件读取图片
212
+ image_storage = reply.content
213
+ image_storage.seek(0)
214
+ itchat.send_image(image_storage, toUserName=receiver)
215
+ logger.info("[WX] sendImage, receiver={}".format(receiver))
216
+ elif reply.type == ReplyType.FILE: # 新增文件回复类型
217
+ file_storage = reply.content
218
+ itchat.send_file(file_storage, toUserName=receiver)
219
+ logger.info("[WX] sendFile, receiver={}".format(receiver))
220
+ elif reply.type == ReplyType.VIDEO: # 新增视频回复类型
221
+ video_storage = reply.content
222
+ itchat.send_video(video_storage, toUserName=receiver)
223
+ logger.info("[WX] sendVideo, receiver={}".format(receiver))
224
+ elif reply.type == ReplyType.VIDEO_URL: # 新增视频URL回复类型
225
+ video_url = reply.content
226
+ logger.debug(f"[WX] start download video, video_url={video_url}")
227
+ video_res = requests.get(video_url, stream=True)
228
+ video_storage = io.BytesIO()
229
+ size = 0
230
+ for block in video_res.iter_content(1024):
231
+ size += len(block)
232
+ video_storage.write(block)
233
+ logger.info(f"[WX] download video success, size={size}, video_url={video_url}")
234
+ video_storage.seek(0)
235
+ itchat.send_video(video_storage, toUserName=receiver)
236
+ logger.info("[WX] sendVideo url={}, receiver={}".format(video_url, receiver))
channel/wechat/wechat_message.py ADDED
@@ -0,0 +1,102 @@
1
+ import re
2
+
3
+ from bridge.context import ContextType
4
+ from channel.chat_message import ChatMessage
5
+ from common.log import logger
6
+ from common.tmp_dir import TmpDir
7
+ from lib import itchat
8
+ from lib.itchat.content import *
9
+
10
+ class WechatMessage(ChatMessage):
11
+ def __init__(self, itchat_msg, is_group=False):
12
+ super().__init__(itchat_msg)
13
+ self.msg_id = itchat_msg["MsgId"]
14
+ self.create_time = itchat_msg["CreateTime"]
15
+ self.is_group = is_group
16
+
17
+ if itchat_msg["Type"] == TEXT:
18
+ self.ctype = ContextType.TEXT
19
+ self.content = itchat_msg["Text"]
20
+ elif itchat_msg["Type"] == VOICE:
21
+ self.ctype = ContextType.VOICE
22
+ self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
23
+ self._prepare_fn = lambda: itchat_msg.download(self.content)
24
+ elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3:
25
+ self.ctype = ContextType.IMAGE
26
+ self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
27
+ self._prepare_fn = lambda: itchat_msg.download(self.content)
28
+ elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
29
+ if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
30
+ # 这里只能得到nickname, actual_user_id还是机器人的id
31
+ if "加入了群聊" in itchat_msg["Content"]:
32
+ self.ctype = ContextType.JOIN_GROUP
33
+ self.content = itchat_msg["Content"]
34
+ self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
35
+ elif "加入群聊" in itchat_msg["Content"]:
36
+ self.ctype = ContextType.JOIN_GROUP
37
+ self.content = itchat_msg["Content"]
38
+ self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
39
+
40
+ elif is_group and ("移出了群聊" in itchat_msg["Content"]):
41
+ self.ctype = ContextType.EXIT_GROUP
42
+ self.content = itchat_msg["Content"]
43
+ self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
44
+
45
+ elif "你已添加了" in itchat_msg["Content"]: #通过好友请求
46
+ self.ctype = ContextType.ACCEPT_FRIEND
47
+ self.content = itchat_msg["Content"]
48
+ elif "拍了拍我" in itchat_msg["Content"]:
49
+ self.ctype = ContextType.PATPAT
50
+ self.content = itchat_msg["Content"]
51
+ if is_group:
52
+ self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
53
+ else:
54
+ raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
55
+ elif itchat_msg["Type"] == ATTACHMENT:
56
+ self.ctype = ContextType.FILE
57
+ self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
58
+ self._prepare_fn = lambda: itchat_msg.download(self.content)
59
+ elif itchat_msg["Type"] == SHARING:
60
+ self.ctype = ContextType.SHARING
61
+ self.content = itchat_msg.get("Url")
62
+
63
+ else:
64
+ raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))
65
+
66
+ self.from_user_id = itchat_msg["FromUserName"]
67
+ self.to_user_id = itchat_msg["ToUserName"]
68
+
69
+ user_id = itchat.instance.storageClass.userName
70
+ nickname = itchat.instance.storageClass.nickName
71
+
72
+ # 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下
73
+ # 以下很繁琐,一句话总结:能填的都填了。
74
+ if self.from_user_id == user_id:
75
+ self.from_user_nickname = nickname
76
+ if self.to_user_id == user_id:
77
+ self.to_user_nickname = nickname
78
+ try: # 陌生人时候, User字段可能不存在
79
+ # my_msg 为True是表示是自己发送的消息
80
+ self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \
81
+ itchat_msg["ToUserName"] != itchat_msg["FromUserName"]
82
+ self.other_user_id = itchat_msg["User"]["UserName"]
83
+ self.other_user_nickname = itchat_msg["User"]["NickName"]
84
+ if self.other_user_id == self.from_user_id:
85
+ self.from_user_nickname = self.other_user_nickname
86
+ if self.other_user_id == self.to_user_id:
87
+ self.to_user_nickname = self.other_user_nickname
88
+ if itchat_msg["User"].get("Self"):
89
+ # 自身的展示名,当设置了群昵称时,该字段表示群昵称
90
+ self.self_display_name = itchat_msg["User"].get("Self").get("DisplayName")
91
+ except KeyError as e: # 处理偶尔没有对方信息的情况
92
+ logger.warn("[WX]get other_user_id failed: " + str(e))
93
+ if self.from_user_id == user_id:
94
+ self.other_user_id = self.to_user_id
95
+ else:
96
+ self.other_user_id = self.from_user_id
97
+
98
+ if self.is_group:
99
+ self.is_at = itchat_msg["IsAt"]
100
+ self.actual_user_id = itchat_msg["ActualUserName"]
101
+ if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.EXIT_GROUP]:
102
+ self.actual_user_nickname = itchat_msg["ActualNickName"]
channel/wechat/wechaty_channel.py ADDED
@@ -0,0 +1,129 @@
1
+ # encoding:utf-8
2
+
3
+ """
4
+ wechaty channel
5
+ Python Wechaty - https://github.com/wechaty/python-wechaty
6
+ """
7
+ import asyncio
8
+ import base64
9
+ import os
10
+ import time
11
+
12
+ from wechaty import Contact, Wechaty
13
+ from wechaty.user import Message
14
+ from wechaty_puppet import FileBox
15
+
16
+ from bridge.context import *
17
+ from bridge.context import Context
18
+ from bridge.reply import *
19
+ from channel.chat_channel import ChatChannel
20
+ from channel.wechat.wechaty_message import WechatyMessage
21
+ from common.log import logger
22
+ from common.singleton import singleton
23
+ from config import conf
24
+
25
+ try:
26
+ from voice.audio_convert import any_to_sil
27
+ except Exception as e:
28
+ pass
29
+
30
+
31
+ @singleton
32
+ class WechatyChannel(ChatChannel):
33
+ NOT_SUPPORT_REPLYTYPE = []
34
+
35
+ def __init__(self):
36
+ super().__init__()
37
+
38
+ def startup(self):
39
+ config = conf()
40
+ token = config.get("wechaty_puppet_service_token")
41
+ os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token
42
+ asyncio.run(self.main())
43
+
44
+ async def main(self):
45
+ loop = asyncio.get_event_loop()
46
+ # 将asyncio的loop传入处理线程
47
+ self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
48
+ self.bot = Wechaty()
49
+ self.bot.on("login", self.on_login)
50
+ self.bot.on("message", self.on_message)
51
+ await self.bot.start()
52
+
53
+ async def on_login(self, contact: Contact):
54
+ self.user_id = contact.contact_id
55
+ self.name = contact.name
56
+ logger.info("[WX] login user={}".format(contact))
57
+
58
+ # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
59
+ def send(self, reply: Reply, context: Context):
60
+ receiver_id = context["receiver"]
61
+ loop = asyncio.get_event_loop()
62
+ if context["isgroup"]:
63
+ receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
64
+ else:
65
+ receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
66
+ msg = None
67
+ if reply.type == ReplyType.TEXT:
68
+ msg = reply.content
69
+ asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
70
+ logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
71
+ elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
72
+ msg = reply.content
73
+ asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
74
+ logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
75
+ elif reply.type == ReplyType.VOICE:
76
+ voiceLength = None
77
+ file_path = reply.content
78
+ sil_file = os.path.splitext(file_path)[0] + ".sil"
79
+ voiceLength = int(any_to_sil(file_path, sil_file))
80
+ if voiceLength >= 60000:
81
+ voiceLength = 60000
82
+ logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
83
+ # 发送语音
84
+ t = int(time.time())
85
+ msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
86
+ if voiceLength is not None:
87
+ msg.metadata["voiceLength"] = voiceLength
88
+ asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
89
+ try:
90
+ os.remove(file_path)
91
+ if sil_file != file_path:
92
+ os.remove(sil_file)
93
+ except Exception as e:
94
+ pass
95
+ logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
96
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
97
+ img_url = reply.content
98
+ t = int(time.time())
99
+ msg = FileBox.from_url(url=img_url, name=str(t) + ".png")
100
+ asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
101
+ logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
102
+ elif reply.type == ReplyType.IMAGE: # 从文件读取图片
103
+ image_storage = reply.content
104
+ image_storage.seek(0)
105
+ t = int(time.time())
106
+ msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
107
+ asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
108
+ logger.info("[WX] sendImage, receiver={}".format(receiver))
109
+
110
+ async def on_message(self, msg: Message):
111
+ """
112
+ listen for message event
113
+ """
114
+ try:
115
+ cmsg = await WechatyMessage(msg)
116
+ except NotImplementedError as e:
117
+ logger.debug("[WX] {}".format(e))
118
+ return
119
+ except Exception as e:
120
+ logger.exception("[WX] {}".format(e))
121
+ return
122
+ logger.debug("[WX] message:{}".format(cmsg))
123
+ room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
124
+ isgroup = room is not None
125
+ ctype = cmsg.ctype
126
+ context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
127
+ if context:
128
+ logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context))
129
+ self.produce(context)
channel/wechat/wechaty_message.py ADDED
@@ -0,0 +1,89 @@
1
+ import asyncio
2
+ import re
3
+
4
+ from wechaty import MessageType
5
+ from wechaty.user import Message
6
+
7
+ from bridge.context import ContextType
8
+ from channel.chat_message import ChatMessage
9
+ from common.log import logger
10
+ from common.tmp_dir import TmpDir
11
+
12
+
13
+ class aobject(object):
14
+ """Inheriting this class allows you to define an async __init__.
15
+
16
+ So you can create objects by doing something like `await MyClass(params)`
17
+ """
18
+
19
+ async def __new__(cls, *a, **kw):
20
+ instance = super().__new__(cls)
21
+ await instance.__init__(*a, **kw)
22
+ return instance
23
+
24
+ async def __init__(self):
25
+ pass
26
+
27
+
28
+ class WechatyMessage(ChatMessage, aobject):
29
+ async def __init__(self, wechaty_msg: Message):
30
+ super().__init__(wechaty_msg)
31
+
32
+ room = wechaty_msg.room()
33
+
34
+ self.msg_id = wechaty_msg.message_id
35
+ self.create_time = wechaty_msg.payload.timestamp
36
+ self.is_group = room is not None
37
+
38
+ if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
39
+ self.ctype = ContextType.TEXT
40
+ self.content = wechaty_msg.text()
41
+ elif wechaty_msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
42
+ self.ctype = ContextType.VOICE
43
+ voice_file = await wechaty_msg.to_file_box()
44
+ self.content = TmpDir().path() + voice_file.name # content直接存临时目录路径
45
+
46
+ def func():
47
+ loop = asyncio.get_event_loop()
48
+ asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()
49
+
50
+ self._prepare_fn = func
51
+
52
+ else:
53
+ raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
54
+
55
+ from_contact = wechaty_msg.talker() # 获取消息的发送者
56
+ self.from_user_id = from_contact.contact_id
57
+ self.from_user_nickname = from_contact.name
58
+
59
+ # group中的from和to,wechaty跟itchat含义不一样
60
+ # wechaty: from是消息实际发送者, to:所在群
61
+ # itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己
62
+ # 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户
63
+
64
+ if self.is_group:
65
+ self.to_user_id = room.room_id
66
+ self.to_user_nickname = await room.topic()
67
+ else:
68
+ to_contact = wechaty_msg.to()
69
+ self.to_user_id = to_contact.contact_id
70
+ self.to_user_nickname = to_contact.name
71
+
72
+ if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
73
+ self.other_user_id = self.to_user_id
74
+ self.other_user_nickname = self.to_user_nickname
75
+ else:
76
+ self.other_user_id = self.from_user_id
77
+ self.other_user_nickname = self.from_user_nickname
78
+
79
+ if self.is_group: # wechaty群聊中,实际发送用户就是from_user
80
+ self.is_at = await wechaty_msg.mention_self()
81
+ if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
82
+ name = wechaty_msg.wechaty.user_self().name
83
+ pattern = f"@{re.escape(name)}(\u2005|\u0020)"
84
+ if re.search(pattern, self.content):
85
+ logger.debug(f"wechaty message {self.msg_id} include at")
86
+ self.is_at = True
87
+
88
+ self.actual_user_id = self.from_user_id
89
+ self.actual_user_nickname = self.from_user_nickname
channel/wechatcom/README.md ADDED
@@ -0,0 +1,85 @@
1
+ # 企业微信应用号channel
2
+
3
+ 企业微信官方提供了客服、应用等API,本channel使用的是企业微信的自建应用API的能力。
4
+
5
+ 因为未来可能还会开发客服能力,所以本channel的类型名叫作`wechatcom_app`。
6
+
7
+ `wechatcom_app` channel支持插件系统和图片声音交互等能力,除了无法加入群聊,作为个人使用的私人助理已绰绰有余。
8
+
9
+ ## 开始之前
10
+
11
+ - 在企业中确认自己拥有在企业内自建应用的权限。
12
+ - 如果没有权限或者是个人用户,也可创建未认证的企业。操作方式:登录手机企业微信,选择`创建/加入企业`来创建企业,类型请选择企业,企业名称可随意填写。
13
+ 未认证的企业有100人的服务人数上限,其他功能与认证企业没有差异。
14
+
15
+ 本channel需安装的依赖与公众号一致,需要安装`wechatpy`和`web.py`,它们包含在`requirements-optional.txt`中。
16
+
17
+ 此外,如果你是`Linux`系统,除了`ffmpeg`还需要安装`amr`编码器,否则会出现找不到编码器的错误,无法正常使用语音功能。
18
+
19
+ - Ubuntu/Debian
20
+
21
+ ```bash
22
+ apt-get install libavcodec-extra
23
+ ```
24
+
25
+ - Alpine
26
+
27
+ 需自行编译`ffmpeg`,在编译参数里加入`amr`编码器的支持
28
+
29
+ ## 使用方法
30
+
31
+ 1.查看企业ID
32
+
33
+ - 扫码登陆[企业微信后台](https://work.weixin.qq.com)
34
+ - 选择`我的企业`,点击`企业信息`,记住该`企业ID`
35
+
36
+ 2.创建自建应用
37
+
38
+ - 选择应用管理, 在自建区选创建应用来创建企业自建应用
39
+ - 上传应用logo,填写应用名称等项
40
+ - 创建应用后进入应用详情页面,记住`AgentId`和`Secret`
41
+
42
+ 3.配置应用
43
+
44
+ - 在详情页点击`企业可信IP`的配置(没看到可以不管),填入你服务器的公网IP,如果不知道可以先不填
45
+ - 点击`接收消息`下的启用API接收消息
46
+ - `URL`填写格式为`http://url:port/wxcomapp`,`port`是程序监听的端口,默认是9898
47
+ 如果是未认证的企业,url可直接使用服务器的IP。如果是认证企业,需要使用备案的域名,可使用二级域名。
48
+ - `Token`可随意填写,停留在这个页面
49
+ - 在程序根目录`config.json`中增加配置(**去掉注释**),`wechatcomapp_aes_key`是当前页面的`EncodingAESKey`
50
+
51
+ ```python
52
+ "channel_type": "wechatcom_app",
53
+ "wechatcom_corp_id": "", # 企业微信公司的corpID
54
+ "wechatcomapp_token": "", # 企业微信app的token
55
+ "wechatcomapp_port": 9898, # 企业微信app的服务端口, 不需要端口转发
56
+ "wechatcomapp_secret": "", # 企业微信app的secret
57
+ "wechatcomapp_agent_id": "", # 企业微信app的agent_id
58
+ "wechatcomapp_aes_key": "", # 企业微信app的aes_key
59
+ ```
60
+
61
+ - 运行程序,在页面中点击保存,保存成功说明验证成功(验证的握手原理见本文末尾的示例)
62
+
63
+ 4.连接个人微信
64
+
65
+ 选择`我的企业`,点击`微信插件`,下面有个邀请关注的二维码。微信扫码后,即可在微信中看到对应企业,在这里你便可以和机器人沟通。
66
+
67
+ 向机器人发送消息,如果日志里出现报错:
68
+
69
+ ```bash
70
+ Error code: 60020, message: "not allow to access from your ip, ...from ip: xx.xx.xx.xx"
71
+ ```
72
+
73
+ 意思是IP不可信,需要参考上一步的`企业可信IP`配置,把这里的IP加进去。
74
+
75
+ ~~### Railway部署方式~~(2023-06-08已失效)
76
+
77
+ ~~公众号不能在`Railway`上部署,但企业微信应用[可以](https://railway.app/template/-FHS--?referralCode=RC3znh)!~~
78
+
79
+ ~~填写配置后,将部署完成后的网址```**.railway.app/wxcomapp```,填写在上一步的URL中。发送信息后观察日志,把报错的IP加入到可信IP。(每次重启后都需要加入可信IP)~~
80
+
81
+ ## 测试体验
82
+
83
+ AIGC开放社区中已经部署了多个可免费使用的Bot,扫描下方的二维码会自动邀请你来体验。
84
+
85
+ <img width="200" src="../../docs/images/aigcopen.png">
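
For reference, the "保存" verification in step 3 succeeds because the service decrypts and echoes back the `echostr` that 企业微信 sends to the callback URL. A condensed sketch of that handshake, mirroring the `Query.GET` handler in the channel code that follows (parameters come from the query string; error handling is simplified):

```python
# Condensed sketch of the callback URL verification; the real handler is
# Query.GET in channel/wechatcom/wechatcomapp_channel.py below.
from wechatpy.enterprise.crypto import WeChatCrypto
from wechatpy.exceptions import InvalidSignatureException


def verify_callback(msg_signature, timestamp, nonce, echostr,
                    token, aes_key, corp_id):
    """Return the decrypted echostr; 企业微信 expects it back verbatim."""
    crypto = WeChatCrypto(token, aes_key, corp_id)
    try:
        return crypto.check_signature(msg_signature, timestamp, nonce, echostr)
    except InvalidSignatureException:
        # 签名校验失败:通常是 token / aes_key / corp_id 配置不一致
        raise
```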
channel/wechatcom/wechatcomapp_channel.py ADDED
@@ -0,0 +1,178 @@
1
+ # -*- coding=utf-8 -*-
2
+ import io
3
+ import os
4
+ import time
5
+
6
+ import requests
7
+ import web
8
+ from wechatpy.enterprise import create_reply, parse_message
9
+ from wechatpy.enterprise.crypto import WeChatCrypto
10
+ from wechatpy.enterprise.exceptions import InvalidCorpIdException
11
+ from wechatpy.exceptions import InvalidSignatureException, WeChatClientException
12
+
13
+ from bridge.context import Context
14
+ from bridge.reply import Reply, ReplyType
15
+ from channel.chat_channel import ChatChannel
16
+ from channel.wechatcom.wechatcomapp_client import WechatComAppClient
17
+ from channel.wechatcom.wechatcomapp_message import WechatComAppMessage
18
+ from common.log import logger
19
+ from common.singleton import singleton
20
+ from common.utils import compress_imgfile, fsize, split_string_by_utf8_length
21
+ from config import conf, subscribe_msg
22
+ from voice.audio_convert import any_to_amr, split_audio
23
+
24
+ MAX_UTF8_LEN = 2048
25
+
26
+
27
+ @singleton
28
+ class WechatComAppChannel(ChatChannel):
29
+ NOT_SUPPORT_REPLYTYPE = []
30
+
31
+ def __init__(self):
32
+ super().__init__()
33
+ self.corp_id = conf().get("wechatcom_corp_id")
34
+ self.secret = conf().get("wechatcomapp_secret")
35
+ self.agent_id = conf().get("wechatcomapp_agent_id")
36
+ self.token = conf().get("wechatcomapp_token")
37
+ self.aes_key = conf().get("wechatcomapp_aes_key")
38
+ print(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
39
+ logger.info(
40
+ "[wechatcom] init: corp_id: {}, secret: {}, agent_id: {}, token: {}, aes_key: {}".format(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
41
+ )
42
+ self.crypto = WeChatCrypto(self.token, self.aes_key, self.corp_id)
43
+ self.client = WechatComAppClient(self.corp_id, self.secret)
44
+
45
+ def startup(self):
46
+ # start message listener
47
+ urls = ("/wxcomapp", "channel.wechatcom.wechatcomapp_channel.Query")
48
+ app = web.application(urls, globals(), autoreload=False)
49
+ port = conf().get("wechatcomapp_port", 9898)
50
+ web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
51
+
52
+ def send(self, reply: Reply, context: Context):
53
+ receiver = context["receiver"]
54
+ if reply.type in [ReplyType.TEXT, ReplyType.ERROR, ReplyType.INFO]:
55
+ reply_text = reply.content
56
+ texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
57
+ if len(texts) > 1:
58
+ logger.info("[wechatcom] text too long, split into {} parts".format(len(texts)))
59
+ for i, text in enumerate(texts):
60
+ self.client.message.send_text(self.agent_id, receiver, text)
61
+ if i != len(texts) - 1:
62
+ time.sleep(0.5) # 休眠0.5秒,防止发送过快乱序
63
+ logger.info("[wechatcom] Do send text to {}: {}".format(receiver, reply_text))
64
+ elif reply.type == ReplyType.VOICE:
65
+ try:
66
+ media_ids = []
67
+ file_path = reply.content
68
+ amr_file = os.path.splitext(file_path)[0] + ".amr"
69
+ any_to_amr(file_path, amr_file)
70
+ duration, files = split_audio(amr_file, 60 * 1000)
71
+ if len(files) > 1:
72
+ logger.info("[wechatcom] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
73
+ for path in files:
74
+ response = self.client.media.upload("voice", open(path, "rb"))
75
+ logger.debug("[wechatcom] upload voice response: {}".format(response))
76
+ media_ids.append(response["media_id"])
77
+ except WeChatClientException as e:
78
+ logger.error("[wechatcom] upload voice failed: {}".format(e))
79
+ return
80
+ try:
81
+ os.remove(file_path)
82
+ if amr_file != file_path:
83
+ os.remove(amr_file)
84
+ except Exception:
85
+ pass
86
+ for media_id in media_ids:
87
+ self.client.message.send_voice(self.agent_id, receiver, media_id)
88
+ time.sleep(1)
89
+ logger.info("[wechatcom] sendVoice={}, receiver={}".format(reply.content, receiver))
90
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
91
+ img_url = reply.content
92
+ pic_res = requests.get(img_url, stream=True)
93
+ image_storage = io.BytesIO()
94
+ for block in pic_res.iter_content(1024):
95
+ image_storage.write(block)
96
+ sz = fsize(image_storage)
97
+ if sz >= 10 * 1024 * 1024:
98
+ logger.info("[wechatcom] image too large, ready to compress, sz={}".format(sz))
99
+ image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
100
+ logger.info("[wechatcom] image compressed, sz={}".format(fsize(image_storage)))
101
+ image_storage.seek(0)
102
+ try:
103
+ response = self.client.media.upload("image", image_storage)
104
+ logger.debug("[wechatcom] upload image response: {}".format(response))
105
+ except WeChatClientException as e:
106
+ logger.error("[wechatcom] upload image failed: {}".format(e))
107
+ return
108
+
109
+ self.client.message.send_image(self.agent_id, receiver, response["media_id"])
110
+ logger.info("[wechatcom] sendImage url={}, receiver={}".format(img_url, receiver))
111
+ elif reply.type == ReplyType.IMAGE: # 从文件读取图片
112
+ image_storage = reply.content
113
+ sz = fsize(image_storage)
114
+ if sz >= 10 * 1024 * 1024:
115
+ logger.info("[wechatcom] image too large, ready to compress, sz={}".format(sz))
116
+ image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
117
+ logger.info("[wechatcom] image compressed, sz={}".format(fsize(image_storage)))
118
+ image_storage.seek(0)
119
+ try:
120
+ response = self.client.media.upload("image", image_storage)
121
+ logger.debug("[wechatcom] upload image response: {}".format(response))
122
+ except WeChatClientException as e:
123
+ logger.error("[wechatcom] upload image failed: {}".format(e))
124
+ return
125
+ self.client.message.send_image(self.agent_id, receiver, response["media_id"])
126
+ logger.info("[wechatcom] sendImage, receiver={}".format(receiver))
127
+
128
+
129
+ class Query:
130
+ def GET(self):
131
+ channel = WechatComAppChannel()
132
+ params = web.input()
133
+ logger.info("[wechatcom] receive params: {}".format(params))
134
+ try:
135
+ signature = params.msg_signature
136
+ timestamp = params.timestamp
137
+ nonce = params.nonce
138
+ echostr = params.echostr
139
+ echostr = channel.crypto.check_signature(signature, timestamp, nonce, echostr)
140
+ except InvalidSignatureException:
141
+ raise web.Forbidden()
142
+ return echostr
143
+
144
+ def POST(self):
145
+ channel = WechatComAppChannel()
146
+ params = web.input()
147
+ logger.info("[wechatcom] receive params: {}".format(params))
148
+ try:
149
+ signature = params.msg_signature
150
+ timestamp = params.timestamp
151
+ nonce = params.nonce
152
+ message = channel.crypto.decrypt_message(web.data(), signature, timestamp, nonce)
153
+ except (InvalidSignatureException, InvalidCorpIdException):
154
+ raise web.Forbidden()
155
+ msg = parse_message(message)
156
+ logger.debug("[wechatcom] receive message: {}, msg= {}".format(message, msg))
157
+ if msg.type == "event":
158
+ if msg.event == "subscribe":
159
+ reply_content = subscribe_msg()
160
+ if reply_content:
161
+ reply = create_reply(reply_content, msg).render()
162
+ res = channel.crypto.encrypt_message(reply, nonce, timestamp)
163
+ return res
164
+ else:
165
+ try:
166
+ wechatcom_msg = WechatComAppMessage(msg, client=channel.client)
167
+ except NotImplementedError as e:
168
+ logger.debug("[wechatcom] " + str(e))
169
+ return "success"
170
+ context = channel._compose_context(
171
+ wechatcom_msg.ctype,
172
+ wechatcom_msg.content,
173
+ isgroup=False,
174
+ msg=wechatcom_msg,
175
+ )
176
+ if context:
177
+ channel.produce(context)
178
+ return "success"
channel/wechatcom/wechatcomapp_client.py ADDED
@@ -0,0 +1,21 @@
+ import threading
+ import time
+
+ from wechatpy.enterprise import WeChatClient
+
+
+ class WechatComAppClient(WeChatClient):
+     def __init__(self, corp_id, secret, access_token=None, session=None, timeout=None, auto_retry=True):
+         super(WechatComAppClient, self).__init__(corp_id, secret, access_token, session, timeout, auto_retry)
+         self.fetch_access_token_lock = threading.Lock()
+
+     def fetch_access_token(self):  # override the parent method with a lock so concurrent threads do not fetch the access_token repeatedly
+         with self.fetch_access_token_lock:
+             access_token = self.session.get(self.access_token_key)
+             if access_token:
+                 if not self.expires_at:
+                     return access_token
+                 timestamp = time.time()
+                 if self.expires_at - timestamp > 60:
+                     return access_token
+             return super().fetch_access_token()
channel/wechatcom/wechatcomapp_message.py ADDED
@@ -0,0 +1,52 @@
1
+ from wechatpy.enterprise import WeChatClient
2
+
3
+ from bridge.context import ContextType
4
+ from channel.chat_message import ChatMessage
5
+ from common.log import logger
6
+ from common.tmp_dir import TmpDir
7
+
8
+
9
+ class WechatComAppMessage(ChatMessage):
10
+ def __init__(self, msg, client: WeChatClient, is_group=False):
11
+ super().__init__(msg)
12
+ self.msg_id = msg.id
13
+ self.create_time = msg.time
14
+ self.is_group = is_group
15
+
16
+ if msg.type == "text":
17
+ self.ctype = ContextType.TEXT
18
+ self.content = msg.content
19
+ elif msg.type == "voice":
20
+ self.ctype = ContextType.VOICE
21
+ self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
22
+
23
+ def download_voice():
24
+ # 如果响应状态码是200,则将响应内容写入本地文件
25
+ response = client.media.download(msg.media_id)
26
+ if response.status_code == 200:
27
+ with open(self.content, "wb") as f:
28
+ f.write(response.content)
29
+ else:
30
+ logger.info(f"[wechatcom] Failed to download voice file, {response.content}")
31
+
32
+ self._prepare_fn = download_voice
33
+ elif msg.type == "image":
34
+ self.ctype = ContextType.IMAGE
35
+ self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
36
+
37
+ def download_image():
38
+ # 如果响应状态码是200,则将响应内容写入本地文件
39
+ response = client.media.download(msg.media_id)
40
+ if response.status_code == 200:
41
+ with open(self.content, "wb") as f:
42
+ f.write(response.content)
43
+ else:
44
+ logger.info(f"[wechatcom] Failed to download image file, {response.content}")
45
+
46
+ self._prepare_fn = download_image
47
+ else:
48
+ raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
49
+
50
+ self.from_user_id = msg.source
51
+ self.to_user_id = msg.target
52
+ self.other_user_id = msg.source
channel/wechatmp/README.md ADDED
@@ -0,0 +1,100 @@
+ # WeChat Official Account channel
+
+ Since logging in to a personal WeChat account on a server via itchat carries a risk of the account being banned, this channel targets WeChat Official Accounts instead and provides a risk-free service.
+ Both subscription accounts and service accounts are currently supported, and both handle text conversations as well as voice and image input. A subscription account owned by an individual cannot pass WeChat verification, so it is subject to a reply-time limit and to daily quotas on image and voice replies.
+
+ ## Usage (subscription account; service accounts work similarly)
+
+ Before deploying, you need a server with a public IP so that the WeChat servers can reach yours, or you need to set up NAT traversal; otherwise the WeChat servers cannot deliver messages to your server.
+
+ In addition, the Python web framework web.py and wechatpy must be installed on the server.
+ Taking Ubuntu as an example (tested on Ubuntu 22.04):
+ ```
+ pip3 install web.py
+ pip3 install wechatpy
+ ```
+
+ Then register your own Official Account on the [WeChat Official Accounts Platform](https://mp.weixin.qq.com); choose "subscription account" as the type, with an individual as the owner.
+
+ Next, following the [access guide](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html), fill in the server address `URL` and the token `Token` under "Settings and Development" - "Basic Configuration" - "Server Configuration" on the [WeChat Official Accounts Platform](https://mp.weixin.qq.com). The `URL` must have the form `http://url/wx` (a bare IP may or may not be accepted), and `Token` is an arbitrary token that you make up. If you select an encrypted message mode, you also need to set `wechatmp_aes_key` in the configuration.
+
+ The server verification code is already in place; you do not need to add any code. You only need to add the following to `config.json` in the project root:
+ ```
+ "channel_type": "wechatmp",     # if the account has passed WeChat verification, replace "wechatmp" with "wechatmp_service" for a much better experience
+ "wechatmp_token": "xxxx",       # Token of the Official Accounts Platform
+ "wechatmp_port": 8080,          # port this channel listens on; needs to be forwarded to 80 or 443
+ "wechatmp_app_id": "xxxx",      # appID of the Official Accounts Platform
+ "wechatmp_app_secret": "xxxx",  # appsecret of the Official Accounts Platform
+ "wechatmp_aes_key": "",         # EncodingAESKey of the Official Accounts Platform, required in encrypted mode
+ "single_chat_prefix": [""],     # recommended: any message triggers a reply, no prefix required
+ "single_chat_reply_prefix": "", # recommended: no prefix on replies
+ "plugin_trigger_prefix": "&",   # recommended: in the mobile WeChat client, symbols such as $%^ render with a large gap next to Chinese text, which hurts the experience. Do not use the admin command prefix "#", which causes unknown problems.
+ ```
+ Then run `python3 app.py` to start the web server. It listens on port 8080 by default, but the Official Account server configuration only accepts ports 80/443; there are two ways to deal with this. The first, recommended, way is to forward port 80 to 8080:
+ ```
+ sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080
+ sudo iptables-save > /etc/iptables/rules.v4
+ ```
+ The second way is to have the Python program listen on port 80 directly: set `"wechatmp_port": 80` in the configuration file and start it with `sudo python3 app.py` on Linux. However, this brings a series of environment and permission problems, so it is not recommended.
+
+ Port 443 works the same way; note that it requires SSL, i.e. HTTPS access, and the corresponding certificate paths must be set in `wechatmp_channel.py`.
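+ For reference, here is a minimal sketch of what enabling SSL could look like; these lines already exist, commented out, near the top of `wechatmp_channel.py`, and the certificate and key paths are placeholders you must replace with your own:
+ ```
+ from cheroot.server import HTTPServer
+ from cheroot.ssl.builtin import BuiltinSSLAdapter
+ HTTPServer.ssl_adapter = BuiltinSSLAdapter(
+     certificate='/ssl/cert.pem',
+     private_key='/ssl/cert.key')
+ ```
+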
+ Once the program is running and listening on the port, click `提交` (Submit) in the "Server Configuration" page mentioned above to verify your server.
+ Then enable the server on the [WeChat Official Accounts Platform](https://mp.weixin.qq.com) and turn off the manually configured auto-reply rules, and ChatGPT replies will be delivered automatically.
+
+ After that, add the server's IP address to the IP whitelist under the account's developer information.
+
+ Otherwise, once the channel is enabled, sending voice or image messages may fail with an error like:
+ ```
+ 'errcode': 40164, 'errmsg': 'invalid ip xx.xx.xx.xx not in whitelist rid
+ ```
+
+
+ ## Limitations of personal Official Accounts
+ Because a personal Official Account cannot pass WeChat verification, it has no customer-service API, so the account cannot send messages on its own and can only reply passively. WeChat allows at most 5 seconds per passive reply with at most 2 retries, which leaves an automatic-reply window of at most 15 seconds. If the question is complex or the server is busy, ChatGPT's answer cannot be returned to the user in time. To work around this, replies are cached: after a reply times out, you need to send any text (for example "1") to try to fetch the cached answer. To improve the experience, a timeout of two minutes (120 seconds) is currently used, so the user receives either the reply or the reason for the failure within at most two minutes.
+
+ In addition, WeChat limits the length of automatic replies, so ChatGPT's answer is split into pieces to stay within the limit.
+
+ ## Private api_key
+ The shared API is rate-limited (a free account allows at most 3 ChatGPT API calls per minute), which becomes a problem when serving many users. Therefore a private api_key can be configured per user, currently via a command of the godcmd plugin.
+
+ ## Voice input
+ Voice input relies on WeChat's built-in speech recognition. You need to enable "接收语音识别结果" (receive speech recognition results) on the "Settings and Development" -> "Interface Permissions" page of the account console.
+
+ ## Voice replies
+ Add the following entry to the configuration file:
+ ```
+ "voice_reply_voice": true,
+ ```
+ The account will then reply to voice messages with voice, enabling spoken conversations.
+
+ The default text-to-speech engine is `google`, which is free to use.
+
+ To choose another text-to-speech engine, add the following option:
+ ```
+ "text_to_voice": "pytts"
+ ```
+
+ pytts is a local text-to-speech engine. baidu and azure are also supported, but you need to configure the relevant dependencies and keys yourself.
+
+ If you use pytts, install the following dependencies on Ubuntu:
+ ```
+ sudo apt update
+ sudo apt install espeak
+ sudo apt install ffmpeg
+ python3 -m pip install pyttsx3
+ ```
+ Enabling pytts voice replies is not really recommended: synthesis runs offline on the local machine, slow synthesis can drag down the server, and the voice does not sound good.
+
+ ## Image replies
+ Both verified and unverified accounts can now reply with images and voice. However, unverified accounts use the permanent-material API, which is limited to 1000 calls per day (with 10 quota resets per month; the program automatically resets the quota when the limit is hit), and the permanent-material library also has a size cap. Therefore, for unverified accounts, the material is deleted from the permanent library within 10 seconds after the image or voice reply is sent.
+
+ ## Testing
+ Testing is currently done on the `RoboStyle` Official Account (based on the [wechatmp branch](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp)); feel free to follow it and try it out. The godcmd, Banwords, role, dungeon and finish plugins are enabled; the other plugins have not been tested thoroughly. The Baidu backend has not been tested yet. The [wechatmp-stable branch](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp-stable) is the previous, more stable version, but it lacks the latest features.
+
+ ## TODO
+ - [x] Voice input
+ - [x] Image input
+ - [x] Image and voice replies for verified accounts via the temporary-material API
+ - [x] Image and voice replies for unverified accounts via the permanent-material API
+ - [ ] High-concurrency support
channel/wechatmp/active_reply.py ADDED
@@ -0,0 +1,75 @@
1
+ import time
2
+
3
+ import web
4
+ from wechatpy import parse_message
5
+ from wechatpy.replies import create_reply
6
+
7
+ from bridge.context import *
8
+ from bridge.reply import *
9
+ from channel.wechatmp.common import *
10
+ from channel.wechatmp.wechatmp_channel import WechatMPChannel
11
+ from channel.wechatmp.wechatmp_message import WeChatMPMessage
12
+ from common.log import logger
13
+ from config import conf, subscribe_msg
14
+
15
+
16
+ # This class is instantiated once per query
17
+ class Query:
18
+ def GET(self):
19
+ return verify_server(web.input())
20
+
21
+ def POST(self):
22
+ # Make sure to return the instance that first created, @singleton will do that.
23
+ try:
24
+ args = web.input()
25
+ verify_server(args)
26
+ channel = WechatMPChannel()
27
+ message = web.data()
28
+ encrypt_func = lambda x: x
29
+ if args.get("encrypt_type") == "aes":
30
+ logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
31
+ if not channel.crypto:
32
+ raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
33
+ message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
34
+ encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
35
+ else:
36
+ logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
37
+ msg = parse_message(message)
38
+ if msg.type in ["text", "voice", "image"]:
39
+ wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
40
+ from_user = wechatmp_msg.from_user_id
41
+ content = wechatmp_msg.content
42
+ message_id = wechatmp_msg.msg_id
43
+
44
+ logger.info(
45
+ "[wechatmp] {}:{} Receive post query {} {}: {}".format(
46
+ web.ctx.env.get("REMOTE_ADDR"),
47
+ web.ctx.env.get("REMOTE_PORT"),
48
+ from_user,
49
+ message_id,
50
+ content,
51
+ )
52
+ )
53
+ if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
54
+ context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
55
+ else:
56
+ context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
57
+ if context:
58
+ channel.produce(context)
59
+ # The reply will be sent by channel.send() in another thread
60
+ return "success"
61
+ elif msg.type == "event":
62
+ logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
63
+ if msg.event in ["subscribe", "subscribe_scan"]:
64
+ reply_text = subscribe_msg()
65
+ if reply_text:
66
+ replyPost = create_reply(reply_text, msg)
67
+ return encrypt_func(replyPost.render())
68
+ else:
69
+ return "success"
70
+ else:
71
+ logger.info("[wechatmp] event not handled for now")
72
+ return "success"
73
+ except Exception as exc:
74
+ logger.exception(exc)
75
+ return exc
channel/wechatmp/common.py ADDED
@@ -0,0 +1,27 @@
1
+ import web
2
+ from wechatpy.crypto import WeChatCrypto
3
+ from wechatpy.exceptions import InvalidSignatureException
4
+ from wechatpy.utils import check_signature
5
+
6
+ from config import conf
7
+
8
+ MAX_UTF8_LEN = 2048
9
+
10
+
11
+ class WeChatAPIException(Exception):
12
+ pass
13
+
14
+
15
+ def verify_server(data):
16
+ try:
17
+ signature = data.signature
18
+ timestamp = data.timestamp
19
+ nonce = data.nonce
20
+ echostr = data.get("echostr", None)
21
+ token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写
22
+ check_signature(token, signature, timestamp, nonce)
23
+ return echostr
24
+ except InvalidSignatureException:
25
+ raise web.Forbidden("Invalid signature")
26
+ except Exception as e:
27
+ raise web.Forbidden(str(e))
channel/wechatmp/passive_reply.py ADDED
@@ -0,0 +1,211 @@
1
+ import asyncio
2
+ import time
3
+
4
+ import web
5
+ from wechatpy import parse_message
6
+ from wechatpy.replies import ImageReply, VoiceReply, create_reply
7
+ import textwrap
8
+ from bridge.context import *
9
+ from bridge.reply import *
10
+ from channel.wechatmp.common import *
11
+ from channel.wechatmp.wechatmp_channel import WechatMPChannel
12
+ from channel.wechatmp.wechatmp_message import WeChatMPMessage
13
+ from common.log import logger
14
+ from common.utils import split_string_by_utf8_length
15
+ from config import conf, subscribe_msg
16
+
17
+
18
+ # This class is instantiated once per query
19
+ class Query:
20
+ def GET(self):
21
+ return verify_server(web.input())
22
+
23
+ def POST(self):
24
+ try:
25
+ args = web.input()
26
+ verify_server(args)
27
+ request_time = time.time()
28
+ channel = WechatMPChannel()
29
+ message = web.data()
30
+ encrypt_func = lambda x: x
31
+ if args.get("encrypt_type") == "aes":
32
+ logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
33
+ if not channel.crypto:
34
+ raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
35
+ message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
36
+ encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
37
+ else:
38
+ logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
39
+ msg = parse_message(message)
40
+ if msg.type in ["text", "voice", "image"]:
41
+ wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
42
+ from_user = wechatmp_msg.from_user_id
43
+ content = wechatmp_msg.content
44
+ message_id = wechatmp_msg.msg_id
45
+
46
+ supported = True
47
+ if "【收到不支持的消息类型,暂无法显示】" in content:
48
+ supported = False # not supported, used to refresh
49
+
50
+ # New request
51
+ if (
52
+ channel.cache_dict.get(from_user) is None
53
+ and from_user not in channel.running
54
+ or content.startswith("#")
55
+ and message_id not in channel.request_cnt # insert the godcmd
56
+ ):
57
+ # The first query begin
58
+ if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
59
+ context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
60
+ else:
61
+ context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
62
+ logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported))
63
+
64
+ if supported and context:
65
+ channel.running.add(from_user)
66
+ channel.produce(context)
67
+ else:
68
+ trigger_prefix = conf().get("single_chat_prefix", [""])[0]
69
+ if trigger_prefix or not supported:
70
+ if trigger_prefix:
71
+ reply_text = textwrap.dedent(
72
+ f"""\
73
+ 请输入'{trigger_prefix}'接你想说的话跟我说话。
74
+ 例如:
75
+ {trigger_prefix}你好,很高兴见到你。"""
76
+ )
77
+ else:
78
+ reply_text = textwrap.dedent(
79
+ """\
80
+ 你好,很高兴见到你。
81
+ 请跟我说话吧。"""
82
+ )
83
+ else:
84
+ logger.error(f"[wechatmp] unknown error")
85
+ reply_text = textwrap.dedent(
86
+ """\
87
+ 未知错误,请稍后再试"""
88
+ )
89
+
90
+ replyPost = create_reply(reply_text, msg)
91
+ return encrypt_func(replyPost.render())
92
+
93
+ # Wechat official server will request 3 times (5 seconds each), with the same message_id.
94
+ # Because the interval is 5 seconds, here assumed that do not have multithreading problems.
95
+ request_cnt = channel.request_cnt.get(message_id, 0) + 1
96
+ channel.request_cnt[message_id] = request_cnt
97
+ logger.info(
98
+ "[wechatmp] Request {} from {} {} {}:{}\n{}".format(
99
+ request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content
100
+ )
101
+ )
102
+
103
+ task_running = True
104
+ waiting_until = request_time + 4
105
+ while time.time() < waiting_until:
106
+ if from_user in channel.running:
107
+ time.sleep(0.1)
108
+ else:
109
+ task_running = False
110
+ break
111
+
112
+ reply_text = ""
113
+ if task_running:
114
+ if request_cnt < 3:
115
+ # waiting for timeout (the POST request will be closed by Wechat official server)
116
+ time.sleep(2)
117
+ # and do nothing, waiting for the next request
118
+ return "success"
119
+ else: # request_cnt == 3:
120
+ # return timeout message
121
+ reply_text = "【正在思考中,回复任意文字尝试获取回复】"
122
+ replyPost = create_reply(reply_text, msg)
123
+ return encrypt_func(replyPost.render())
124
+
125
+ # reply is ready
126
+ channel.request_cnt.pop(message_id)
127
+
128
+ # no return because of bandwords or other reasons
129
+ if from_user not in channel.cache_dict and from_user not in channel.running:
130
+ return "success"
131
+
132
+ # Only one request can access to the cached data
133
+ try:
134
+ (reply_type, reply_content) = channel.cache_dict[from_user].pop(0)
135
+ if not channel.cache_dict[from_user]: # If popping the message makes the list empty, delete the user entry from cache
136
+ del channel.cache_dict[from_user]
137
+ except IndexError:
138
+ return "success"
139
+
140
+ if reply_type == "text":
141
+ if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN:
142
+ reply_text = reply_content
143
+ else:
144
+ continue_text = "\n【未完待续,回复任意文字以继续】"
145
+ splits = split_string_by_utf8_length(
146
+ reply_content,
147
+ MAX_UTF8_LEN - len(continue_text.encode("utf-8")),
148
+ max_split=1,
149
+ )
150
+ reply_text = splits[0] + continue_text
151
+ channel.cache_dict[from_user].append(("text", splits[1]))
152
+
153
+ logger.info(
154
+ "[wechatmp] Request {} do send to {} {}: {}\n{}".format(
155
+ request_cnt,
156
+ from_user,
157
+ message_id,
158
+ content,
159
+ reply_text,
160
+ )
161
+ )
162
+ replyPost = create_reply(reply_text, msg)
163
+ return encrypt_func(replyPost.render())
164
+
165
+ elif reply_type == "voice":
166
+ media_id = reply_content
167
+ asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
168
+ logger.info(
169
+ "[wechatmp] Request {} do send to {} {}: {} voice media_id {}".format(
170
+ request_cnt,
171
+ from_user,
172
+ message_id,
173
+ content,
174
+ media_id,
175
+ )
176
+ )
177
+ replyPost = VoiceReply(message=msg)
178
+ replyPost.media_id = media_id
179
+ return encrypt_func(replyPost.render())
180
+
181
+ elif reply_type == "image":
182
+ media_id = reply_content
183
+ asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
184
+ logger.info(
185
+ "[wechatmp] Request {} do send to {} {}: {} image media_id {}".format(
186
+ request_cnt,
187
+ from_user,
188
+ message_id,
189
+ content,
190
+ media_id,
191
+ )
192
+ )
193
+ replyPost = ImageReply(message=msg)
194
+ replyPost.media_id = media_id
195
+ return encrypt_func(replyPost.render())
196
+
197
+ elif msg.type == "event":
198
+ logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
199
+ if msg.event in ["subscribe", "subscribe_scan"]:
200
+ reply_text = subscribe_msg()
201
+ if reply_text:
202
+ replyPost = create_reply(reply_text, msg)
203
+ return encrypt_func(replyPost.render())
204
+ else:
205
+ return "success"
206
+ else:
207
+ logger.info("[wechatmp] event not handled for now")
208
+ return "success"
209
+ except Exception as exc:
210
+ logger.exception(exc)
211
+ return exc
channel/wechatmp/wechatmp_channel.py ADDED
@@ -0,0 +1,236 @@
1
+ # -*- coding: utf-8 -*-
2
+ import asyncio
3
+ import imghdr
4
+ import io
5
+ import os
6
+ import threading
7
+ import time
8
+
9
+ import requests
10
+ import web
11
+ from wechatpy.crypto import WeChatCrypto
12
+ from wechatpy.exceptions import WeChatClientException
13
+ from collections import defaultdict
14
+
15
+ from bridge.context import *
16
+ from bridge.reply import *
17
+ from channel.chat_channel import ChatChannel
18
+ from channel.wechatmp.common import *
19
+ from channel.wechatmp.wechatmp_client import WechatMPClient
20
+ from common.log import logger
21
+ from common.singleton import singleton
22
+ from common.utils import split_string_by_utf8_length
23
+ from config import conf
24
+ from voice.audio_convert import any_to_mp3, split_audio
25
+
26
+ # If using SSL, uncomment the following lines, and modify the certificate path.
27
+ # from cheroot.server import HTTPServer
28
+ # from cheroot.ssl.builtin import BuiltinSSLAdapter
29
+ # HTTPServer.ssl_adapter = BuiltinSSLAdapter(
30
+ # certificate='/ssl/cert.pem',
31
+ # private_key='/ssl/cert.key')
32
+
33
+
34
+ @singleton
35
+ class WechatMPChannel(ChatChannel):
36
+ def __init__(self, passive_reply=True):
37
+ super().__init__()
38
+ self.passive_reply = passive_reply
39
+ self.NOT_SUPPORT_REPLYTYPE = []
40
+ appid = conf().get("wechatmp_app_id")
41
+ secret = conf().get("wechatmp_app_secret")
42
+ token = conf().get("wechatmp_token")
43
+ aes_key = conf().get("wechatmp_aes_key")
44
+ self.client = WechatMPClient(appid, secret)
45
+ self.crypto = None
46
+ if aes_key:
47
+ self.crypto = WeChatCrypto(token, aes_key, appid)
48
+ if self.passive_reply:
49
+ # Cache the reply to the user's first message
50
+ self.cache_dict = defaultdict(list)
51
+ # Record whether the current message is being processed
52
+ self.running = set()
53
+ # Count the request from wechat official server by message_id
54
+ self.request_cnt = dict()
55
+ # The permanent media need to be deleted to avoid media number limit
56
+ self.delete_media_loop = asyncio.new_event_loop()
57
+ t = threading.Thread(target=self.start_loop, args=(self.delete_media_loop,))
58
+ t.setDaemon(True)
59
+ t.start()
60
+
61
+ def startup(self):
62
+ if self.passive_reply:
63
+ urls = ("/wx", "channel.wechatmp.passive_reply.Query")
64
+ else:
65
+ urls = ("/wx", "channel.wechatmp.active_reply.Query")
66
+ app = web.application(urls, globals(), autoreload=False)
67
+ port = conf().get("wechatmp_port", 8080)
68
+ web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
69
+
70
+ def start_loop(self, loop):
71
+ asyncio.set_event_loop(loop)
72
+ loop.run_forever()
73
+
74
+ async def delete_media(self, media_id):
75
+ logger.debug("[wechatmp] permanent media {} will be deleted in 10s".format(media_id))
76
+ await asyncio.sleep(10)
77
+ self.client.material.delete(media_id)
78
+ logger.info("[wechatmp] permanent media {} has been deleted".format(media_id))
79
+
80
+ def send(self, reply: Reply, context: Context):
81
+ receiver = context["receiver"]
82
+ if self.passive_reply:
83
+ if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
84
+ reply_text = reply.content
85
+ logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
86
+ self.cache_dict[receiver].append(("text", reply_text))
87
+ elif reply.type == ReplyType.VOICE:
88
+ voice_file_path = reply.content
89
+ duration, files = split_audio(voice_file_path, 60 * 1000)
90
+ if len(files) > 1:
91
+ logger.info("[wechatmp] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
92
+
93
+ for path in files:
94
+ # support: <2M, <60s, mp3/wma/wav/amr
95
+ try:
96
+ with open(path, "rb") as f:
97
+ response = self.client.material.add("voice", f)
98
+ logger.debug("[wechatmp] upload voice response: {}".format(response))
99
+ f_size = os.fstat(f.fileno()).st_size
100
+ time.sleep(1.0 + 2 * f_size / 1024 / 1024)
101
+ # todo check media_id
102
+ except WeChatClientException as e:
103
+ logger.error("[wechatmp] upload voice failed: {}".format(e))
104
+ return
105
+ media_id = response["media_id"]
106
+ logger.info("[wechatmp] voice uploaded, receiver {}, media_id {}".format(receiver, media_id))
107
+ self.cache_dict[receiver].append(("voice", media_id))
108
+
109
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
110
+ img_url = reply.content
111
+ pic_res = requests.get(img_url, stream=True)
112
+ image_storage = io.BytesIO()
113
+ for block in pic_res.iter_content(1024):
114
+ image_storage.write(block)
115
+ image_storage.seek(0)
116
+ image_type = imghdr.what(image_storage)
117
+ filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
118
+ content_type = "image/" + image_type
119
+ try:
120
+ response = self.client.material.add("image", (filename, image_storage, content_type))
121
+ logger.debug("[wechatmp] upload image response: {}".format(response))
122
+ except WeChatClientException as e:
123
+ logger.error("[wechatmp] upload image failed: {}".format(e))
124
+ return
125
+ media_id = response["media_id"]
126
+ logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
127
+ self.cache_dict[receiver].append(("image", media_id))
128
+ elif reply.type == ReplyType.IMAGE: # 从文件读取图片
129
+ image_storage = reply.content
130
+ image_storage.seek(0)
131
+ image_type = imghdr.what(image_storage)
132
+ filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
133
+ content_type = "image/" + image_type
134
+ try:
135
+ response = self.client.material.add("image", (filename, image_storage, content_type))
136
+ logger.debug("[wechatmp] upload image response: {}".format(response))
137
+ except WeChatClientException as e:
138
+ logger.error("[wechatmp] upload image failed: {}".format(e))
139
+ return
140
+ media_id = response["media_id"]
141
+ logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
142
+ self.cache_dict[receiver].append(("image", media_id))
143
+ else:
144
+ if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
145
+ reply_text = reply.content
146
+ texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
147
+ if len(texts) > 1:
148
+ logger.info("[wechatmp] text too long, split into {} parts".format(len(texts)))
149
+ for i, text in enumerate(texts):
150
+ self.client.message.send_text(receiver, text)
151
+ if i != len(texts) - 1:
152
+ time.sleep(0.5) # 休眠0.5秒,防止发送过快乱序
153
+ logger.info("[wechatmp] Do send text to {}: {}".format(receiver, reply_text))
154
+ elif reply.type == ReplyType.VOICE:
155
+ try:
156
+ file_path = reply.content
157
+ file_name = os.path.basename(file_path)
158
+ file_type = os.path.splitext(file_name)[1]
159
+ if file_type == ".mp3":
160
+ file_type = "audio/mpeg"
161
+ elif file_type == ".amr":
162
+ file_type = "audio/amr"
163
+ else:
164
+ mp3_file = os.path.splitext(file_path)[0] + ".mp3"
165
+ any_to_mp3(file_path, mp3_file)
166
+ file_path = mp3_file
167
+ file_name = os.path.basename(file_path)
168
+ file_type = "audio/mpeg"
169
+ logger.info("[wechatmp] file_name: {}, file_type: {} ".format(file_name, file_type))
170
+ media_ids = []
171
+ duration, files = split_audio(file_path, 60 * 1000)
172
+ if len(files) > 1:
173
+ logger.info("[wechatmp] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
174
+ for path in files:
175
+ # support: <2M, <60s, AMR\MP3
176
+ response = self.client.media.upload("voice", (os.path.basename(path), open(path, "rb"), file_type))
177
+ logger.debug("[wechatmp] upload voice response: {}".format(response))
178
+ media_ids.append(response["media_id"])
179
+ os.remove(path)
180
+ except WeChatClientException as e:
181
+ logger.error("[wechatmp] upload voice failed: {}".format(e))
182
+ return
183
+
184
+ try:
185
+ os.remove(file_path)
186
+ except Exception:
187
+ pass
188
+
189
+ for media_id in media_ids:
190
+ self.client.message.send_voice(receiver, media_id)
191
+ time.sleep(1)
192
+ logger.info("[wechatmp] Do send voice to {}".format(receiver))
193
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
194
+ img_url = reply.content
195
+ pic_res = requests.get(img_url, stream=True)
196
+ image_storage = io.BytesIO()
197
+ for block in pic_res.iter_content(1024):
198
+ image_storage.write(block)
199
+ image_storage.seek(0)
200
+ image_type = imghdr.what(image_storage)
201
+ filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
202
+ content_type = "image/" + image_type
203
+ try:
204
+ response = self.client.media.upload("image", (filename, image_storage, content_type))
205
+ logger.debug("[wechatmp] upload image response: {}".format(response))
206
+ except WeChatClientException as e:
207
+ logger.error("[wechatmp] upload image failed: {}".format(e))
208
+ return
209
+ self.client.message.send_image(receiver, response["media_id"])
210
+ logger.info("[wechatmp] Do send image to {}".format(receiver))
211
+ elif reply.type == ReplyType.IMAGE: # 从文件读取图片
212
+ image_storage = reply.content
213
+ image_storage.seek(0)
214
+ image_type = imghdr.what(image_storage)
215
+ filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
216
+ content_type = "image/" + image_type
217
+ try:
218
+ response = self.client.media.upload("image", (filename, image_storage, content_type))
219
+ logger.debug("[wechatmp] upload image response: {}".format(response))
220
+ except WeChatClientException as e:
221
+ logger.error("[wechatmp] upload image failed: {}".format(e))
222
+ return
223
+ self.client.message.send_image(receiver, response["media_id"])
224
+ logger.info("[wechatmp] Do send image to {}".format(receiver))
225
+ return
226
+
227
+ def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
228
+ logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id))
229
+ if self.passive_reply:
230
+ self.running.remove(session_id)
231
+
232
+ def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
233
+ logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception))
234
+ if self.passive_reply:
235
+ assert session_id not in self.cache_dict
236
+ self.running.remove(session_id)
channel/wechatmp/wechatmp_client.py ADDED
@@ -0,0 +1,49 @@
+ import threading
+ import time
+
+ from wechatpy.client import WeChatClient
+ from wechatpy.exceptions import APILimitedException
+
+ from channel.wechatmp.common import *
+ from common.log import logger
+
+
+ class WechatMPClient(WeChatClient):
+     def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True):
+         super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry)
+         self.fetch_access_token_lock = threading.Lock()
+         self.clear_quota_lock = threading.Lock()
+         self.last_clear_quota_time = -1
+
+     def clear_quota(self):
+         return self.post("clear_quota", data={"appid": self.appid})
+
+     def clear_quota_v2(self):
+         return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret})
+
+     def fetch_access_token(self):  # override the parent method with a lock so concurrent threads do not fetch the access_token repeatedly
+         with self.fetch_access_token_lock:
+             access_token = self.session.get(self.access_token_key)
+             if access_token:
+                 if not self.expires_at:
+                     return access_token
+                 timestamp = time.time()
+                 if self.expires_at - timestamp > 60:
+                     return access_token
+             return super().fetch_access_token()
+
+     def _request(self, method, url_or_endpoint, **kwargs):  # override the parent method: when the API quota is exhausted, clear the quota and retry once
+         try:
+             return super()._request(method, url_or_endpoint, **kwargs)
+         except APILimitedException as e:
+             logger.error("[wechatmp] API quota has been used up. {}".format(e))
+             if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60:
+                 with self.clear_quota_lock:
+                     if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60:
+                         self.last_clear_quota_time = time.time()
+                         response = self.clear_quota_v2()
+                         logger.debug("[wechatmp] API quota has been cleared, {}".format(response))
+                 return super()._request(method, url_or_endpoint, **kwargs)
+             else:
+                 logger.error("[wechatmp] last clear quota time is {}, less than 60s ago, skip clearing quota".format(self.last_clear_quota_time))
+                 raise e
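+
+ # Illustrative usage sketch (not part of the original module; the constructor arguments mirror
+ # how wechatmp_channel.py instantiates this client). Because _request() is overridden, callers
+ # use the client exactly like a stock wechatpy WeChatClient: quota clearing and the single
+ # retry happen transparently when APILimitedException is raised.
+ #
+ #   client = WechatMPClient(conf().get("wechatmp_app_id"), conf().get("wechatmp_app_secret"))
+ #   client.message.send_text(open_id, "hello")   # retried once automatically if the API quota was exhausted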
channel/wechatmp/wechatmp_message.py ADDED
@@ -0,0 +1,56 @@
1
+ # -*- coding: utf-8 -*-#
2
+
3
+ from bridge.context import ContextType
4
+ from channel.chat_message import ChatMessage
5
+ from common.log import logger
6
+ from common.tmp_dir import TmpDir
7
+
8
+
9
+ class WeChatMPMessage(ChatMessage):
10
+ def __init__(self, msg, client=None):
11
+ super().__init__(msg)
12
+ self.msg_id = msg.id
13
+ self.create_time = msg.time
14
+ self.is_group = False
15
+
16
+ if msg.type == "text":
17
+ self.ctype = ContextType.TEXT
18
+ self.content = msg.content
19
+ elif msg.type == "voice":
20
+ if msg.recognition is None:
21
+ self.ctype = ContextType.VOICE
22
+ self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
23
+
24
+ def download_voice():
25
+ # 如果响应状态码是200,则将响应内容写入本地文件
26
+ response = client.media.download(msg.media_id)
27
+ if response.status_code == 200:
28
+ with open(self.content, "wb") as f:
29
+ f.write(response.content)
30
+ else:
31
+ logger.info(f"[wechatmp] Failed to download voice file, {response.content}")
32
+
33
+ self._prepare_fn = download_voice
34
+ else:
35
+ self.ctype = ContextType.TEXT
36
+ self.content = msg.recognition
37
+ elif msg.type == "image":
38
+ self.ctype = ContextType.IMAGE
39
+ self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
40
+
41
+ def download_image():
42
+ # 如果响应状态码是200,则将响应内容写入本地文件
43
+ response = client.media.download(msg.media_id)
44
+ if response.status_code == 200:
45
+ with open(self.content, "wb") as f:
46
+ f.write(response.content)
47
+ else:
48
+ logger.info(f"[wechatmp] Failed to download image file, {response.content}")
49
+
50
+ self._prepare_fn = download_image
51
+ else:
52
+ raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
53
+
54
+ self.from_user_id = msg.source
55
+ self.to_user_id = msg.target
56
+ self.other_user_id = msg.source
channel/wework/run.py ADDED
@@ -0,0 +1,17 @@
1
+ import os
2
+ import time
3
+ os.environ['ntwork_LOG'] = "ERROR"
4
+ import ntwork
5
+
6
+ wework = ntwork.WeWork()
7
+
8
+
9
+ def forever():
10
+ try:
11
+ while True:
12
+ time.sleep(0.1)
13
+ except KeyboardInterrupt:
14
+ ntwork.exit_()
15
+ os._exit(0)
16
+
17
+
channel/wework/wework_channel.py ADDED
@@ -0,0 +1,326 @@
1
+ import io
2
+ import os
3
+ import random
4
+ import tempfile
5
+ import threading
6
+ os.environ['ntwork_LOG'] = "ERROR"
7
+ import ntwork
8
+ import requests
9
+ import uuid
10
+
11
+ from bridge.context import *
12
+ from bridge.reply import *
13
+ from channel.chat_channel import ChatChannel
14
+ from channel.wework.wework_message import *
15
+ from channel.wework.wework_message import WeworkMessage
16
+ from common.singleton import singleton
17
+ from common.log import logger
18
+ from common.time_check import time_checker
19
+ from common.utils import compress_imgfile, fsize
20
+ from config import conf
21
+ from channel.wework.run import wework
22
+ from channel.wework import run
23
+ from PIL import Image
24
+
25
+
26
+ def get_wxid_by_name(room_members, group_wxid, name):
27
+ if group_wxid in room_members:
28
+ for member in room_members[group_wxid]['member_list']:
29
+ if member['room_nickname'] == name or member['username'] == name:
30
+ return member['user_id']
31
+ return None # 如果没有找到对应的group_wxid或name,则返回None
32
+
33
+
34
+ def download_and_compress_image(url, filename, quality=30):
35
+ # 确定保存图片的目录
36
+ directory = os.path.join(os.getcwd(), "tmp")
37
+ # 如果目录不存在,则创建目录
38
+ if not os.path.exists(directory):
39
+ os.makedirs(directory)
40
+
41
+ # 下载图片
42
+ pic_res = requests.get(url, stream=True)
43
+ image_storage = io.BytesIO()
44
+ for block in pic_res.iter_content(1024):
45
+ image_storage.write(block)
46
+
47
+ # 检查图片大小并可能进行压缩
48
+ sz = fsize(image_storage)
49
+ if sz >= 10 * 1024 * 1024: # 如果图片大于 10 MB
50
+ logger.info("[wework] image too large, ready to compress, sz={}".format(sz))
51
+ image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
52
+ logger.info("[wework] image compressed, sz={}".format(fsize(image_storage)))
53
+
54
+ # 将内存缓冲区的指针重置到起始位置
55
+ image_storage.seek(0)
56
+
57
+ # 读取并保存图片
58
+ image = Image.open(image_storage)
59
+ image_path = os.path.join(directory, f"{filename}.png")
60
+ image.save(image_path, "png")
61
+
62
+ return image_path
63
+
64
+
65
+ def download_video(url, filename):
66
+ # 确定保存视频的目录
67
+ directory = os.path.join(os.getcwd(), "tmp")
68
+ # 如果目录不存在,则创建目录
69
+ if not os.path.exists(directory):
70
+ os.makedirs(directory)
71
+
72
+ # 下载视频
73
+ response = requests.get(url, stream=True)
74
+ total_size = 0
75
+
76
+ video_path = os.path.join(directory, f"{filename}.mp4")
77
+
78
+ with open(video_path, 'wb') as f:
79
+ for block in response.iter_content(1024):
80
+ total_size += len(block)
81
+
82
+ # 如果视频的总大小超过30MB (30 * 1024 * 1024 bytes),则停止下载并返回
83
+ if total_size > 30 * 1024 * 1024:
84
+ logger.info("[WX] Video is larger than 30MB, skipping...")
85
+ return None
86
+
87
+ f.write(block)
88
+
89
+ return video_path
90
+
91
+
92
+ def create_message(wework_instance, message, is_group):
93
+ logger.debug(f"正在为{'群聊' if is_group else '单聊'}创建 WeworkMessage")
94
+ cmsg = WeworkMessage(message, wework=wework_instance, is_group=is_group)
95
+ logger.debug(f"cmsg:{cmsg}")
96
+ return cmsg
97
+
98
+
99
+ def handle_message(cmsg, is_group):
100
+ logger.debug(f"准备用 WeworkChannel 处理{'群聊' if is_group else '单聊'}消息")
101
+ if is_group:
102
+ WeworkChannel().handle_group(cmsg)
103
+ else:
104
+ WeworkChannel().handle_single(cmsg)
105
+ logger.debug(f"已用 WeworkChannel 处理完{'群聊' if is_group else '单聊'}消息")
106
+
107
+
108
+ def _check(func):
109
+ def wrapper(self, cmsg: ChatMessage):
110
+ msgId = cmsg.msg_id
111
+ create_time = cmsg.create_time # 消息时间戳
112
+ if create_time is None:
113
+ return func(self, cmsg)
114
+ if int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
115
+ logger.debug("[WX]history message {} skipped".format(msgId))
116
+ return
117
+ return func(self, cmsg)
118
+
119
+ return wrapper
120
+
121
+
122
+ @wework.msg_register(
123
+ [ntwork.MT_RECV_TEXT_MSG, ntwork.MT_RECV_IMAGE_MSG, 11072, ntwork.MT_RECV_VOICE_MSG])
124
+ def all_msg_handler(wework_instance: ntwork.WeWork, message):
125
+ logger.debug(f"收到消息: {message}")
126
+ if 'data' in message:
127
+ # 首先查找conversation_id,如果没有找到,则查找room_conversation_id
128
+ conversation_id = message['data'].get('conversation_id', message['data'].get('room_conversation_id'))
129
+ if conversation_id is not None:
130
+ is_group = "R:" in conversation_id
131
+ try:
132
+ cmsg = create_message(wework_instance=wework_instance, message=message, is_group=is_group)
133
+ except NotImplementedError as e:
134
+ logger.error(f"[WX]{message.get('MsgId', 'unknown')} 跳过: {e}")
135
+ return None
136
+ delay = random.randint(1, 2)
137
+ timer = threading.Timer(delay, handle_message, args=(cmsg, is_group))
138
+ timer.start()
139
+ else:
140
+ logger.debug("消息数据中无 conversation_id")
141
+ return None
142
+ return None
143
+
144
+
145
+ def accept_friend_with_retries(wework_instance, user_id, corp_id):
146
+ result = wework_instance.accept_friend(user_id, corp_id)
147
+ logger.debug(f'result:{result}')
148
+
149
+
150
+ # @wework.msg_register(ntwork.MT_RECV_FRIEND_MSG)
151
+ # def friend(wework_instance: ntwork.WeWork, message):
152
+ # data = message["data"]
153
+ # user_id = data["user_id"]
154
+ # corp_id = data["corp_id"]
155
+ # logger.info(f"接收到好友请求,消息内容:{data}")
156
+ # delay = random.randint(1, 180)
157
+ # threading.Timer(delay, accept_friend_with_retries, args=(wework_instance, user_id, corp_id)).start()
158
+ #
159
+ # return None
160
+
161
+
162
+ def get_with_retry(get_func, max_retries=5, delay=5):
163
+ retries = 0
164
+ result = None
165
+ while retries < max_retries:
166
+ result = get_func()
167
+ if result:
168
+ break
169
+ logger.warning(f"获取数据失败,重试第{retries + 1}次······")
170
+ retries += 1
171
+ time.sleep(delay) # 等待一段时间后重试
172
+ return result
173
+
174
+
175
+ @singleton
176
+ class WeworkChannel(ChatChannel):
177
+ NOT_SUPPORT_REPLYTYPE = []
178
+
179
+ def __init__(self):
180
+ super().__init__()
181
+
182
+ def startup(self):
183
+ smart = conf().get("wework_smart", True)
184
+ wework.open(smart)
185
+ logger.info("等待登录······")
186
+ wework.wait_login()
187
+ login_info = wework.get_login_info()
188
+ self.user_id = login_info['user_id']
189
+ self.name = login_info['nickname']
190
+ logger.info(f"登录信息:>>>user_id:{self.user_id}>>>>>>>>name:{self.name}")
191
+ logger.info("静默延迟60s,等待客户端刷新数据,请勿进行任何操作······")
192
+ time.sleep(60)
193
+ contacts = get_with_retry(wework.get_external_contacts)
194
+ rooms = get_with_retry(wework.get_rooms)
195
+ directory = os.path.join(os.getcwd(), "tmp")
196
+ if not contacts or not rooms:
197
+ logger.error("获取contacts或rooms失败,程序退出")
198
+ ntwork.exit_()
199
+ os._exit(0)
200
+ if not os.path.exists(directory):
201
+ os.makedirs(directory)
202
+ # 将contacts保存到json文件中
203
+ with open(os.path.join(directory, 'wework_contacts.json'), 'w', encoding='utf-8') as f:
204
+ json.dump(contacts, f, ensure_ascii=False, indent=4)
205
+ with open(os.path.join(directory, 'wework_rooms.json'), 'w', encoding='utf-8') as f:
206
+ json.dump(rooms, f, ensure_ascii=False, indent=4)
207
+ # 创建一个空字典来保存结果
208
+ result = {}
209
+
210
+ # 遍历列表中的每个字典
211
+ for room in rooms['room_list']:
212
+ # 获取聊天室ID
213
+ room_wxid = room['conversation_id']
214
+
215
+ # 获取聊天室成员
216
+ room_members = wework.get_room_members(room_wxid)
217
+
218
+ # 将聊天室成员保存到结果字典中
219
+ result[room_wxid] = room_members
220
+
221
+ # 将结果保存到json文件中
222
+ with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
223
+ json.dump(result, f, ensure_ascii=False, indent=4)
224
+ logger.info("wework程序初始化完成········")
225
+ run.forever()
226
+
227
+ @time_checker
228
+ @_check
229
+ def handle_single(self, cmsg: ChatMessage):
230
+ if cmsg.from_user_id == cmsg.to_user_id:
231
+ # ignore self reply
232
+ return
233
+ if cmsg.ctype == ContextType.VOICE:
234
+ if not conf().get("speech_recognition"):
235
+ return
236
+ logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
237
+ elif cmsg.ctype == ContextType.IMAGE:
238
+ logger.debug("[WX]receive image msg: {}".format(cmsg.content))
239
+ elif cmsg.ctype == ContextType.PATPAT:
240
+ logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
241
+ elif cmsg.ctype == ContextType.TEXT:
242
+ logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
243
+ else:
244
+ logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
245
+ context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
246
+ if context:
247
+ self.produce(context)
248
+
249
+ @time_checker
250
+ @_check
251
+ def handle_group(self, cmsg: ChatMessage):
252
+ if cmsg.ctype == ContextType.VOICE:
253
+ if not conf().get("speech_recognition"):
254
+ return
255
+ logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
256
+ elif cmsg.ctype == ContextType.IMAGE:
257
+ logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
258
+ elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
259
+ logger.debug("[WX]receive note msg: {}".format(cmsg.content))
260
+ elif cmsg.ctype == ContextType.TEXT:
261
+ pass
262
+ else:
263
+ logger.debug("[WX]receive group msg: {}".format(cmsg.content))
264
+ context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
265
+ if context:
266
+ self.produce(context)
267
+
268
+ # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
269
+ def send(self, reply: Reply, context: Context):
270
+ logger.debug(f"context: {context}")
271
+ receiver = context["receiver"]
272
+ actual_user_id = context["msg"].actual_user_id
273
+ if reply.type == ReplyType.TEXT or reply.type == ReplyType.TEXT_:
274
+ match = re.search(r"^@(.*?)\n", reply.content)
275
+ logger.debug(f"match: {match}")
276
+ if match:
277
+ new_content = re.sub(r"^@(.*?)\n", "\n", reply.content)
278
+ at_list = [actual_user_id]
279
+ logger.debug(f"new_content: {new_content}")
280
+ wework.send_room_at_msg(receiver, new_content, at_list)
281
+ else:
282
+ wework.send_text(receiver, reply.content)
283
+ logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
284
+ elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
285
+ wework.send_text(receiver, reply.content)
286
+ logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
287
+ elif reply.type == ReplyType.IMAGE: # 从文件读取图片
288
+ image_storage = reply.content
289
+ image_storage.seek(0)
290
+ # Read data from image_storage
291
+ data = image_storage.read()
292
+ # Create a temporary file
293
+ with tempfile.NamedTemporaryFile(delete=False) as temp:
294
+ temp_path = temp.name
295
+ temp.write(data)
296
+ # Send the image
297
+ wework.send_image(receiver, temp_path)
298
+ logger.info("[WX] sendImage, receiver={}".format(receiver))
299
+ # Remove the temporary file
300
+ os.remove(temp_path)
301
+ elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
302
+ img_url = reply.content
303
+ filename = str(uuid.uuid4())
304
+
305
+ # 调用你的函数,下载图片并保存为本地文件
306
+ image_path = download_and_compress_image(img_url, filename)
307
+
308
+ wework.send_image(receiver, file_path=image_path)
309
+ logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
310
+ elif reply.type == ReplyType.VIDEO_URL:
311
+ video_url = reply.content
312
+ filename = str(uuid.uuid4())
313
+ video_path = download_video(video_url, filename)
314
+
315
+ if video_path is None:
316
+ # 如果视频太大,下载可能会被跳过,此时 video_path 将为 None
317
+ wework.send_text(receiver, "抱歉,视频太大了!!!")
318
+ else:
319
+ wework.send_video(receiver, video_path)
320
+ logger.info("[WX] sendVideo, receiver={}".format(receiver))
321
+ elif reply.type == ReplyType.VOICE:
322
+ current_dir = os.getcwd()
323
+ voice_file = reply.content.split("/")[-1]
324
+ reply.content = os.path.join(current_dir, "tmp", voice_file)
325
+ wework.send_file(receiver, reply.content)
326
+ logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
channel/wework/wework_message.py ADDED
@@ -0,0 +1,211 @@
1
+ import datetime
2
+ import json
3
+ import os
4
+ import re
5
+ import time
6
+ import pilk
7
+
8
+ from bridge.context import ContextType
9
+ from channel.chat_message import ChatMessage
10
+ from common.log import logger
11
+ from ntwork.const import send_type
12
+
13
+
14
+ def get_with_retry(get_func, max_retries=5, delay=5):
15
+ retries = 0
16
+ result = None
17
+ while retries < max_retries:
18
+ result = get_func()
19
+ if result:
20
+ break
21
+ logger.warning(f"获取数据失败,重试第{retries + 1}次······")
22
+ retries += 1
23
+ time.sleep(delay) # 等待一段时间后重试
24
+ return result
25
+
26
+
27
+ def get_room_info(wework, conversation_id):
28
+ logger.debug(f"传入的 conversation_id: {conversation_id}")
29
+ rooms = wework.get_rooms()
30
+ if not rooms or 'room_list' not in rooms:
31
+ logger.error(f"获取群聊信息失败: {rooms}")
32
+ return None
33
+ time.sleep(1)
34
+ logger.debug(f"获取到的群聊信息: {rooms}")
35
+ for room in rooms['room_list']:
36
+ if room['conversation_id'] == conversation_id:
37
+ return room
38
+ return None
39
+
40
+
41
+ def cdn_download(wework, message, file_name):
42
+ data = message["data"]
43
+ aes_key = data["cdn"]["aes_key"]
44
+ file_size = data["cdn"]["size"]
45
+
46
+ # 获取当前工作目录,然后与文件名拼接得到保存路径
47
+ current_dir = os.getcwd()
48
+ save_path = os.path.join(current_dir, "tmp", file_name)
49
+
50
+ # 下载保存图片到本地
51
+ if "url" in data["cdn"].keys() and "auth_key" in data["cdn"].keys():
52
+ url = data["cdn"]["url"]
53
+ auth_key = data["cdn"]["auth_key"]
54
+ # result = wework.wx_cdn_download(url, auth_key, aes_key, file_size, save_path) # ntwork库本身接口有问题,缺失了aes_key这个参数
55
+ """
56
+ 下载wx类型的cdn文件,以https开头
57
+ """
58
+ data = {
59
+ 'url': url,
60
+ 'auth_key': auth_key,
61
+ 'aes_key': aes_key,
62
+ 'size': file_size,
63
+ 'save_path': save_path
64
+ }
65
+ result = wework._WeWork__send_sync(send_type.MT_WXCDN_DOWNLOAD_MSG, data) # 直接用wx_cdn_download的接口内部实现来调用
66
+ elif "file_id" in data["cdn"].keys():
67
+ file_type = 2
68
+ file_id = data["cdn"]["file_id"]
69
+ result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
70
+ else:
71
+ logger.error(f"something is wrong, data: {data}")
72
+ return
73
+
74
+ # 输出下载结果
75
+ logger.debug(f"result: {result}")
76
+
77
+
78
+ def c2c_download_and_convert(wework, message, file_name):
79
+ data = message["data"]
80
+ aes_key = data["cdn"]["aes_key"]
81
+ file_size = data["cdn"]["size"]
82
+ file_type = 5
83
+ file_id = data["cdn"]["file_id"]
84
+
85
+ current_dir = os.getcwd()
86
+ save_path = os.path.join(current_dir, "tmp", file_name)
87
+ result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
88
+ logger.debug(result)
89
+
90
+ # 在下载完SILK文件之后,立即将其转换为WAV文件
91
+ base_name, _ = os.path.splitext(save_path)
92
+ wav_file = base_name + ".wav"
93
+ pilk.silk_to_wav(save_path, wav_file, rate=24000)
94
+
95
+ # 删除SILK文件
96
+ try:
97
+ os.remove(save_path)
98
+ except Exception as e:
99
+ pass
100
+
101
+
102
+ class WeworkMessage(ChatMessage):
103
+ def __init__(self, wework_msg, wework, is_group=False):
104
+ try:
105
+ super().__init__(wework_msg)
106
+ self.msg_id = wework_msg['data'].get('conversation_id', wework_msg['data'].get('room_conversation_id'))
107
+ # 使用.get()防止 'send_time' 键不存在时抛出错误
108
+ self.create_time = wework_msg['data'].get("send_time")
109
+ self.is_group = is_group
110
+ self.wework = wework
111
+
112
+ if wework_msg["type"] == 11041: # 文本消息类型
113
+ if any(substring in wework_msg['data']['content'] for substring in ("该消息类型暂不能展示", "不支持的消息类型")):
114
+ return
115
+ self.ctype = ContextType.TEXT
116
+ self.content = wework_msg['data']['content']
117
+ elif wework_msg["type"] == 11044: # 语音消息类型,需要缓存文件
118
+ file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".silk"
119
+ base_name, _ = os.path.splitext(file_name)
120
+ file_name_2 = base_name + ".wav"
121
+ current_dir = os.getcwd()
122
+ self.ctype = ContextType.VOICE
123
+ self.content = os.path.join(current_dir, "tmp", file_name_2)
124
+ self._prepare_fn = lambda: c2c_download_and_convert(wework, wework_msg, file_name)
125
+ elif wework_msg["type"] == 11042: # 图片消息类型,需要下载文件
126
+ file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".jpg"
127
+ current_dir = os.getcwd()
128
+ self.ctype = ContextType.IMAGE
129
+ self.content = os.path.join(current_dir, "tmp", file_name)
130
+ self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name)
131
+ elif wework_msg["type"] == 11072: # 新成员入群通知
132
+ self.ctype = ContextType.JOIN_GROUP
133
+ member_list = wework_msg['data']['member_list']
134
+ self.actual_user_nickname = member_list[0]['name']
135
+ self.actual_user_id = member_list[0]['user_id']
136
+ self.content = f"{self.actual_user_nickname}加入了群聊!"
137
+ directory = os.path.join(os.getcwd(), "tmp")
138
+ rooms = get_with_retry(wework.get_rooms)
139
+ if not rooms:
140
+ logger.error("更新群信息失败···")
141
+ else:
142
+ result = {}
143
+ for room in rooms['room_list']:
144
+ # 获取聊天室ID
145
+ room_wxid = room['conversation_id']
146
+
147
+ # 获取聊天室成员
148
+ room_members = wework.get_room_members(room_wxid)
149
+
150
+ # 将聊天室成员保存到结果字典中
151
+ result[room_wxid] = room_members
152
+ with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
153
+ json.dump(result, f, ensure_ascii=False, indent=4)
154
+ logger.info("有新成员加入,已自动更新群成员列表缓存!")
155
+ else:
156
+ raise NotImplementedError(
157
+ "Unsupported message type: Type:{} MsgType:{}".format(wework_msg["type"], wework_msg["MsgType"]))
158
+
159
+ data = wework_msg['data']
160
+ login_info = self.wework.get_login_info()
161
+ logger.debug(f"login_info: {login_info}")
162
+ nickname = f"{login_info['username']}({login_info['nickname']})" if login_info['nickname'] else login_info['username']
163
+ user_id = login_info['user_id']
164
+
165
+ sender_id = data.get('sender')
166
+ conversation_id = data.get('conversation_id')
167
+ sender_name = data.get("sender_name")
168
+
169
+ self.from_user_id = user_id if sender_id == user_id else conversation_id
170
+ self.from_user_nickname = nickname if sender_id == user_id else sender_name
171
+ self.to_user_id = user_id
172
+ self.to_user_nickname = nickname
173
+ self.other_user_nickname = sender_name
174
+ self.other_user_id = conversation_id
175
+
176
+ if self.is_group:
177
+ conversation_id = data.get('conversation_id') or data.get('room_conversation_id')
178
+ self.other_user_id = conversation_id
179
+ if conversation_id:
180
+ room_info = get_room_info(wework=wework, conversation_id=conversation_id)
181
+ self.other_user_nickname = room_info.get('nickname', None) if room_info else None
182
+ at_list = data.get('at_list', [])
183
+ tmp_list = []
184
+ for at in at_list:
185
+ tmp_list.append(at['nickname'])
186
+ at_list = tmp_list
187
+ logger.debug(f"at_list: {at_list}")
188
+ logger.debug(f"nickname: {nickname}")
189
+ self.is_at = False
190
+ if nickname in at_list or login_info['nickname'] in at_list or login_info['username'] in at_list:
191
+ self.is_at = True
192
+ self.at_list = at_list
193
+
194
+ # 检查消息内容是否包含@用户名。处理复制粘贴的消息,这类消息可能不会触发@通知,但内容中可能包含 "@用户名"。
195
+ content = data.get('content', '')
196
+ name = nickname
197
+ pattern = f"@{re.escape(name)}(\u2005|\u0020)"
198
+ if re.search(pattern, content):
199
+ logger.debug(f"Wechaty message {self.msg_id} includes at")
200
+ self.is_at = True
201
+
202
+ if not self.actual_user_id:
203
+ self.actual_user_id = data.get("sender")
204
+ self.actual_user_nickname = sender_name if self.ctype != ContextType.JOIN_GROUP else self.actual_user_nickname
205
+ else:
206
+ logger.error("群聊消息中没有找到 conversation_id 或 room_conversation_id")
207
+
208
+ logger.debug(f"WeworkMessage has been successfully instantiated with message id: {self.msg_id}")
209
+ except Exception as e:
210
+ logger.error(f"在 WeworkMessage 的初始化过程中出现错误:{e}")
211
+ raise e
common/const.py ADDED
@@ -0,0 +1,23 @@
1
+ # bot_type
2
+ OPEN_AI = "openAI"
3
+ CHATGPT = "chatGPT"
4
+ BAIDU = "baidu"
5
+ XUNFEI = "xunfei"
6
+ CHATGPTONAZURE = "chatGPTOnAzure"
7
+ LINKAI = "linkai"
8
+ CLAUDEAI = "claude"
9
+ QWEN = "qwen"
10
+
11
+ # model
12
+ GPT35 = "gpt-3.5-turbo"
13
+ GPT4 = "gpt-4"
14
+ GPT4_TURBO_PREVIEW = "gpt-4-1106-preview"
15
+ GPT4_VISION_PREVIEW = "gpt-4-vision-preview"
16
+ WHISPER_1 = "whisper-1"
17
+ TTS_1 = "tts-1"
18
+ TTS_1_HD = "tts-1-hd"
19
+
20
+ MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", GPT4_TURBO_PREVIEW, QWEN]
21
+
22
+ # channel
23
+ FEISHU = "feishu"
common/dequeue.py ADDED
@@ -0,0 +1,33 @@
1
+ from queue import Full, Queue
2
+ from time import monotonic as time
3
+
4
+
5
+ # add implementation of putleft to Queue
6
+ class Dequeue(Queue):
7
+ def putleft(self, item, block=True, timeout=None):
8
+ with self.not_full:
9
+ if self.maxsize > 0:
10
+ if not block:
11
+ if self._qsize() >= self.maxsize:
12
+ raise Full
13
+ elif timeout is None:
14
+ while self._qsize() >= self.maxsize:
15
+ self.not_full.wait()
16
+ elif timeout < 0:
17
+ raise ValueError("'timeout' must be a non-negative number")
18
+ else:
19
+ endtime = time() + timeout
20
+ while self._qsize() >= self.maxsize:
21
+ remaining = endtime - time()
22
+ if remaining <= 0.0:
23
+ raise Full
24
+ self.not_full.wait(remaining)
25
+ self._putleft(item)
26
+ self.unfinished_tasks += 1
27
+ self.not_empty.notify()
28
+
29
+ def putleft_nowait(self, item):
30
+ return self.putleft(item, block=False)
31
+
32
+ def _putleft(self, item):
33
+ self.queue.appendleft(item)
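+
+ # Illustrative usage sketch (not part of the original module): putleft() pushes an item to the
+ # front of the queue while still honoring maxsize, so a re-queued item is consumed first.
+ #
+ #   q = Dequeue(maxsize=3)
+ #   q.put("first")
+ #   q.put("second")
+ #   q.putleft("urgent")       # placed at the front
+ #   q.get()                   # -> "urgent"
+ #   q.putleft_nowait("x")     # non-blocking variant; raises queue.Full if the queue is full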
common/expired_dict.py ADDED
@@ -0,0 +1,42 @@
1
+ from datetime import datetime, timedelta
2
+
3
+
4
+ class ExpiredDict(dict):
5
+ def __init__(self, expires_in_seconds):
6
+ super().__init__()
7
+ self.expires_in_seconds = expires_in_seconds
8
+
9
+ def __getitem__(self, key):
10
+ value, expiry_time = super().__getitem__(key)
11
+ if datetime.now() > expiry_time:
12
+ del self[key]
13
+ raise KeyError("expired {}".format(key))
14
+ self.__setitem__(key, value)
15
+ return value
16
+
17
+ def __setitem__(self, key, value):
18
+ expiry_time = datetime.now() + timedelta(seconds=self.expires_in_seconds)
19
+ super().__setitem__(key, (value, expiry_time))
20
+
21
+ def get(self, key, default=None):
22
+ try:
23
+ return self[key]
24
+ except KeyError:
25
+ return default
26
+
27
+ def __contains__(self, key):
28
+ try:
29
+ self[key]
30
+ return True
31
+ except KeyError:
32
+ return False
33
+
34
+ def keys(self):
35
+ keys = list(super().keys())
36
+ return [key for key in keys if key in self]
37
+
38
+ def items(self):
39
+ return [(key, self[key]) for key in self.keys()]
40
+
41
+ def __iter__(self):
42
+ return self.keys().__iter__()
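+
+ # Illustrative usage sketch (not part of the original module): each key carries its own expiry
+ # time; reading a key refreshes it, and expired keys vanish from get(), keys() and iteration.
+ #
+ #   sessions = ExpiredDict(expires_in_seconds=3600)
+ #   sessions["user_1"] = ["hello"]
+ #   "user_1" in sessions        # True while the entry is fresh
+ #   sessions.get("user_1")      # returns the value and resets the expiry
+ #   # after 3600s without access the entry expires: get() returns None and keys() omits it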