lunarflu HF Staff commited on
Commit
97b9b42
·
1 Parent(s): fb515d0

[falcon.py] 1.1 PR

Browse files
Files changed (1) hide show
  1. falcon.py +136 -0
falcon.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio_client import Client
2
+ import os
3
+ import asyncio
4
+ import json
5
+
6
+ HF_TOKEN = os.getenv("HF_TOKEN")
7
+ falcon_userid_threadid_dictionary = {}
8
+ threadid_conversation = {}
9
+ # Instructions are for Falcon-chat and can be found here: https://huggingface.co/spaces/HuggingFaceH4/falcon-chat
10
+ instructions = (
11
+ "The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, and a"
12
+ " human user, called User. In the following interactions, User and Falcon will converse in natural language, and"
13
+ " Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. Falcon was built"
14
+ " by the Technology Innovation Institute in Abu Dhabi. Falcon will never decline to answer a question, and always"
15
+ " attempts to give an answer that User would be satisfied with. It knows a lot, and always tells the truth. The"
16
+ " conversation begins."
17
+ )
18
+ falcon_client = Client("HuggingFaceH4/falcon-chat", HF_TOKEN)
19
+
20
+ BOT_USER_ID = 1086256910572986469 if os.getenv("TEST_ENV", False) else 1102236653545861151
21
+ FALCON_CHANNEL_ID = 1079459939405279232 if os.getenv("TEST_ENV", False) else 1119313248056004729
22
+
23
+
24
+ async def waitjob(job):
25
+ while not job.done():
26
+ await asyncio.sleep(0.2)
27
+
28
+
29
+ def falcon_initial_generation(prompt, instructions, thread):
30
+ """Solves two problems at once; 1) The Slash command + job.submit interaction, and 2) the need for job.submit."""
31
+ global threadid_conversation
32
+
33
+ chathistory = falcon_client.predict(fn_index=5)
34
+ temperature = 0.8
35
+ p_nucleus_sampling = 0.9
36
+
37
+ job = falcon_client.submit(prompt, chathistory, instructions, temperature, p_nucleus_sampling, fn_index=1)
38
+ while job.done() is False:
39
+ pass
40
+ else:
41
+ if os.environ.get("TEST_ENV") == "True":
42
+ print("falcon text gen job done")
43
+ file_paths = job.outputs()
44
+ print(file_paths)
45
+ full_generation = file_paths[-1]
46
+ print(full_generation)
47
+ with open(full_generation, "r") as file:
48
+ data = json.load(file)
49
+ print(data)
50
+ output_text = data[-1][-1]
51
+ threadid_conversation[thread.id] = full_generation
52
+ if len(output_text) > 1300:
53
+ output_text = output_text[:1300] + "...\nTruncating response to 2000 characters due to discord api limits."
54
+ if os.environ.get("TEST_ENV") == "True":
55
+ print(output_text)
56
+ return output_text
57
+
58
+
59
+ async def try_falcon(interaction, prompt):
60
+ """Generates text based on a given prompt"""
61
+ try:
62
+ global falcon_userid_threadid_dictionary # tracks userid-thread existence
63
+ global threadid_conversation
64
+
65
+ if interaction.user.id != BOT_USER_ID:
66
+ if interaction.channel.id == FALCON_CHANNEL_ID:
67
+ if os.environ.get("TEST_ENV") == "True":
68
+ print("Safetychecks passed for try_falcon")
69
+ await interaction.response.send_message("Working on it!")
70
+ channel = interaction.channel
71
+ message = await channel.send("Creating thread...")
72
+ thread = await message.create_thread(name=prompt, auto_archive_duration=60) # interaction.user
73
+ await thread.send(
74
+ "[DISCLAIMER: HuggingBot is a **highly experimental** beta feature; The Falcon model and system"
75
+ " prompt can be found here: https://huggingface.co/spaces/HuggingFaceH4/falcon-chat]"
76
+ )
77
+
78
+ if os.environ.get("TEST_ENV") == "True":
79
+ print("Running falcon_initial_generation...")
80
+ loop = asyncio.get_running_loop()
81
+ output_text = await loop.run_in_executor(None, falcon_initial_generation, prompt, instructions, thread)
82
+ falcon_userid_threadid_dictionary[thread.id] = interaction.user.id
83
+
84
+ await thread.send(output_text)
85
+ except Exception as e:
86
+ print(f"try_falcon Error: {e}")
87
+
88
+
89
+ async def continue_falcon(message):
90
+ """Continues a given conversation based on chathistory"""
91
+ try:
92
+ if not message.author.bot:
93
+ global falcon_userid_threadid_dictionary # tracks userid-thread existence
94
+ if message.channel.id in falcon_userid_threadid_dictionary: # is this a valid thread?
95
+ if (
96
+ falcon_userid_threadid_dictionary[message.channel.id] == message.author.id
97
+ ): # more than that - is this specifically the right user for this thread?
98
+ if os.environ.get("TEST_ENV") == "True":
99
+ print("Safetychecks passed for continue_falcon")
100
+ global instructions
101
+ global threadid_conversation
102
+ await message.add_reaction("🔁")
103
+
104
+ prompt = message.content
105
+ chathistory = threadid_conversation[message.channel.id]
106
+ temperature = 0.8
107
+ p_nucleus_sampling = 0.9
108
+
109
+ if os.environ.get("TEST_ENV") == "True":
110
+ print("Running falcon_client.submit")
111
+ job = falcon_client.submit(
112
+ prompt,
113
+ chathistory,
114
+ instructions,
115
+ temperature,
116
+ p_nucleus_sampling,
117
+ fn_index=1,
118
+ )
119
+ await waitjob(job)
120
+ if os.environ.get("TEST_ENV") == "True":
121
+ print("Continue_falcon job done")
122
+ file_paths = job.outputs()
123
+ full_generation = file_paths[-1]
124
+ with open(full_generation, "r") as file:
125
+ data = json.load(file)
126
+ output_text = data[-1][-1]
127
+ threadid_conversation[message.channel.id] = full_generation # overwrite the old file
128
+ if len(output_text) > 1300:
129
+ output_text = (
130
+ output_text[:1300]
131
+ + "...\nTruncating response to 2000 characters due to discord api limits."
132
+ )
133
+ await message.reply(output_text)
134
+ except Exception as e:
135
+ print(f"continue_falcon Error: {e}")
136
+ await message.reply(f"Error: {e} <@811235357663297546> (continue_falcon error)")