[falcon.py] 1.1 PR
falcon.py
ADDED
@@ -0,0 +1,136 @@
from gradio_client import Client
import os
import asyncio
import json
import time

HF_TOKEN = os.getenv("HF_TOKEN")
falcon_userid_threadid_dictionary = {}  # tracks which user owns which thread (thread.id -> user.id)
threadid_conversation = {}  # latest chathistory file path per thread (thread.id -> path)
# Instructions are for Falcon-chat and can be found here: https://huggingface.co/spaces/HuggingFaceH4/falcon-chat
instructions = (
    "The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, and a"
    " human user, called User. In the following interactions, User and Falcon will converse in natural language, and"
    " Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. Falcon was built"
    " by the Technology Innovation Institute in Abu Dhabi. Falcon will never decline to answer a question, and always"
    " attempts to give an answer that User would be satisfied with. It knows a lot, and always tells the truth. The"
    " conversation begins."
)
falcon_client = Client("HuggingFaceH4/falcon-chat", HF_TOKEN)

# TEST_ENV must be exactly "True" to select the test IDs, matching the checks below.
BOT_USER_ID = 1086256910572986469 if os.getenv("TEST_ENV") == "True" else 1102236653545861151
FALCON_CHANNEL_ID = 1079459939405279232 if os.getenv("TEST_ENV") == "True" else 1119313248056004729


async def waitjob(job):
    """Polls a gradio_client Job until it completes, yielding to the event loop between checks."""
    while not job.done():
        await asyncio.sleep(0.2)
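
# Usage sketch of the submit/poll pattern used throughout this file (illustrative;
# fn_index=1 assumes the falcon-chat Space's current endpoint layout):
#     job = falcon_client.submit(prompt, chathistory, instructions, 0.8, 0.9, fn_index=1)
#     await waitjob(job)               # yields to the event loop every 0.2 s
#     latest_path = job.outputs()[-1]  # path to the Space's newest output file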


def falcon_initial_generation(prompt, instructions, thread):
    """Solves two problems at once: 1) the Slash command + job.submit interaction, and 2) the need for job.submit.

    Synchronous on purpose; try_falcon dispatches it to a worker thread via run_in_executor.
    """
    chathistory = falcon_client.predict(fn_index=5)  # fetch a fresh chathistory file from the Space
    temperature = 0.8
    p_nucleus_sampling = 0.9

    job = falcon_client.submit(prompt, chathistory, instructions, temperature, p_nucleus_sampling, fn_index=1)
    while not job.done():
        time.sleep(0.2)  # blocking is fine here: this runs in an executor thread, not the event loop
    if os.environ.get("TEST_ENV") == "True":
        print("falcon text gen job done")
    file_paths = job.outputs()
    full_generation = file_paths[-1]
    with open(full_generation, "r") as file:
        data = json.load(file)
        output_text = data[-1][-1]  # newest [user, bot] pair, bot half
    if os.environ.get("TEST_ENV") == "True":
        print(file_paths)
        print(full_generation)
        print(data)
    threadid_conversation[thread.id] = full_generation
    if len(output_text) > 1300:
        output_text = output_text[:1300] + "...\nTruncating response due to Discord API message length limits."
    if os.environ.get("TEST_ENV") == "True":
        print(output_text)
    return output_text
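
# Note on the assumed output format (not a documented gradio_client contract):
# the Space writes chathistory as nested [user, bot] pairs, e.g.
#     [["Hi Falcon!", "Hello User!"], ["Tell me a joke.", "..."]]
# which is why data[-1][-1] above picks out the newest bot reply.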


async def try_falcon(interaction, prompt):
    """Generates text based on a given prompt and posts it in a fresh thread."""
    try:
        if interaction.user.id != BOT_USER_ID:
            if interaction.channel.id == FALCON_CHANNEL_ID:
                if os.environ.get("TEST_ENV") == "True":
                    print("Safetychecks passed for try_falcon")
                await interaction.response.send_message("Working on it!")
                channel = interaction.channel
                message = await channel.send("Creating thread...")
                # Discord caps thread names at 100 characters, so trim long prompts.
                thread = await message.create_thread(name=prompt[:100], auto_archive_duration=60)
                await thread.send(
                    "[DISCLAIMER: HuggingBot is a **highly experimental** beta feature; The Falcon model and system"
                    " prompt can be found here: https://huggingface.co/spaces/HuggingFaceH4/falcon-chat]"
                )

                if os.environ.get("TEST_ENV") == "True":
                    print("Running falcon_initial_generation...")
                # Run the blocking generation in a worker thread so the event loop stays responsive.
                loop = asyncio.get_running_loop()
                output_text = await loop.run_in_executor(None, falcon_initial_generation, prompt, instructions, thread)
                falcon_userid_threadid_dictionary[thread.id] = interaction.user.id

                await thread.send(output_text)
    except Exception as e:
        print(f"try_falcon Error: {e}")
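
# Design note: try_falcon above hands the blocking falcon_initial_generation to a
# worker thread via run_in_executor, while continue_falcon below stays on the event
# loop and awaits waitjob instead; two ways of waiting on the same Job.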


async def continue_falcon(message):
    """Continues a given conversation based on the stored chathistory."""
    try:
        if not message.author.bot:
            if message.channel.id in falcon_userid_threadid_dictionary:  # is this a valid thread?
                # more than that: is this specifically the right user for this thread?
                if falcon_userid_threadid_dictionary[message.channel.id] == message.author.id:
                    if os.environ.get("TEST_ENV") == "True":
                        print("Safetychecks passed for continue_falcon")
                    await message.add_reaction("🔁")

                    prompt = message.content
                    chathistory = threadid_conversation[message.channel.id]
                    temperature = 0.8
                    p_nucleus_sampling = 0.9

                    if os.environ.get("TEST_ENV") == "True":
                        print("Running falcon_client.submit")
                    job = falcon_client.submit(
                        prompt,
                        chathistory,
                        instructions,
                        temperature,
                        p_nucleus_sampling,
                        fn_index=1,
                    )
                    await waitjob(job)
                    if os.environ.get("TEST_ENV") == "True":
                        print("continue_falcon job done")
                    file_paths = job.outputs()
                    full_generation = file_paths[-1]
                    with open(full_generation, "r") as file:
                        data = json.load(file)
                        output_text = data[-1][-1]
                    threadid_conversation[message.channel.id] = full_generation  # point the thread at the new file
                    if len(output_text) > 1300:
                        output_text = (
                            output_text[:1300]
                            + "...\nTruncating response due to Discord API message length limits."
                        )
                    await message.reply(output_text)
    except Exception as e:
        print(f"continue_falcon Error: {e}")
        await message.reply(f"Error: {e} <@811235357663297546> (continue_falcon error)")
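
For context, here is a minimal sketch of how these two entry points might be wired into the bot's discord.py client. This wiring is not part of the PR; the command name, token variable name, and intents are assumptions, and the command tree would still need a one-time bot.tree.sync().

# Hypothetical wiring for falcon.py (not included in this PR)
import os

import discord
from discord.ext import commands

import falcon  # the module added above

intents = discord.Intents.default()
intents.message_content = True  # continue_falcon reads message.content
bot = commands.Bot(command_prefix="!", intents=intents)


@bot.tree.command(name="falcon", description="Chat with Falcon in a new thread")
async def falcon_command(interaction: discord.Interaction, prompt: str):
    await falcon.try_falcon(interaction, prompt)


@bot.event
async def on_message(message: discord.Message):
    await falcon.continue_falcon(message)
    await bot.process_commands(message)


bot.run(os.getenv("DISCORD_TOKEN"))  # assumed environment variable name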