File size: 6,239 Bytes
97b9b42
 
58c201b
 
 
97b9b42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58c201b
97b9b42
 
 
 
 
58c201b
 
97b9b42
 
58c201b
f82ba90
 
 
 
 
97b9b42
 
 
 
 
 
 
 
 
58c201b
97b9b42
 
 
 
 
 
 
 
 
 
 
 
58c201b
 
97b9b42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58c201b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import asyncio
import json
import os

from gradio_client import Client

HF_TOKEN = os.getenv("HF_TOKEN")
falcon_userid_threadid_dictionary = {}
threadid_conversation = {}
# Instructions are for Falcon-chat and can be found here:  https://huggingface.co/spaces/HuggingFaceH4/falcon-chat
instructions = (
    "The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, and a"
    " human user, called User. In the following interactions, User and Falcon will converse in natural language, and"
    " Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. Falcon was built"
    " by the Technology Innovation Institute in Abu Dhabi. Falcon will never decline to answer a question, and always"
    " attempts to give an answer that User would be satisfied with. It knows a lot, and always tells the truth. The"
    " conversation begins."
)
falcon_client = Client("HuggingFaceH4/falcon-chat", HF_TOKEN)

BOT_USER_ID = 1086256910572986469 if os.getenv("TEST_ENV", False) else 1102236653545861151
FALCON_CHANNEL_ID = 1079459939405279232 if os.getenv("TEST_ENV", False) else 1119313248056004729


async def waitjob(job):
    while not job.done():
        await asyncio.sleep(0.2)


def falcon_initial_generation(prompt, instructions, thread):
    """Solves two problems at once; 1) The Slash command + job.submit interaction, and 2) the need for job.submit."""
    global threadid_conversation

    chathistory = falcon_client.predict(fn_index=5)
    temperature = 0.8
    p_nucleus_sampling = 0.9

    job = falcon_client.submit(prompt, chathistory, instructions, temperature, p_nucleus_sampling, fn_index=1)
    while job.done() is False:
        pass
    else:
        if os.environ.get("TEST_ENV") == "True":
            print("falcon text gen job done")
        file_paths = job.outputs()
        print(file_paths)
        full_generation = file_paths[-1]
        print(full_generation)
        with open(full_generation, "r") as file:
            data = json.load(file)
            print(data)
        output_text = data[-1][-1]
        threadid_conversation[thread.id] = full_generation
        if len(output_text) > 1300:
            output_text = output_text[:1300] + "...\nTruncating response to 2000 characters due to discord api limits."
        if os.environ.get("TEST_ENV") == "True":
            print(output_text)
        return output_text


async def try_falcon(ctx, prompt):
    """Generates text based on a given prompt"""
    try:
        global falcon_userid_threadid_dictionary  # tracks userid-thread existence
        global threadid_conversation

        if ctx.author.id != BOT_USER_ID:
            if ctx.channel.id == FALCON_CHANNEL_ID:
                if os.environ.get("TEST_ENV") == "True":
                    print("Safetychecks passed for try_falcon")
                message = await ctx.send(f"**{prompt}** - {ctx.author.mention}")
                if len(prompt) > 99:
                    small_prompt = prompt[:99]
                else:
                    small_prompt = prompt
                thread = await message.create_thread(name=small_prompt, auto_archive_duration=60)  # interaction.user
                await thread.send(
                    "[DISCLAIMER: HuggingBot is a **highly experimental** beta feature; The Falcon model and system"
                    " prompt can be found here: https://huggingface.co/spaces/HuggingFaceH4/falcon-chat]"
                )

                if os.environ.get("TEST_ENV") == "True":
                    print("Running falcon_initial_generation...")
                loop = asyncio.get_running_loop()
                output_text = await loop.run_in_executor(None, falcon_initial_generation, prompt, instructions, thread)
                falcon_userid_threadid_dictionary[thread.id] = ctx.author.id

                await thread.send(output_text)
    except Exception as e:
        print(f"try_falcon Error: {e}")


async def continue_falcon(message):
    """Continues a given conversation based on chathistory"""
    try:
        if not message.author.bot:
            global falcon_userid_threadid_dictionary  # tracks userid-thread existence
            if message.channel.id in falcon_userid_threadid_dictionary:  # is this a valid thread?
                if falcon_userid_threadid_dictionary[message.channel.id] == message.author.id:
                    print("Safetychecks passed for continue_falcon")
                    global instructions
                    global threadid_conversation
                    await message.add_reaction("πŸ”")

                    prompt = message.content
                    chathistory = threadid_conversation[message.channel.id]
                    temperature = 0.8
                    p_nucleus_sampling = 0.9

                    if os.environ.get("TEST_ENV") == "True":
                        print("Running falcon_client.submit")
                    job = falcon_client.submit(
                        prompt,
                        chathistory,
                        instructions,
                        temperature,
                        p_nucleus_sampling,
                        fn_index=1,
                    )
                    await waitjob(job)
                    if os.environ.get("TEST_ENV") == "True":
                        print("Continue_falcon job done")
                    file_paths = job.outputs()
                    full_generation = file_paths[-1]
                    with open(full_generation, "r") as file:
                        data = json.load(file)
                        output_text = data[-1][-1]
                    threadid_conversation[message.channel.id] = full_generation  # overwrite the old file
                    if len(output_text) > 1300:
                        output_text = (
                            output_text[:1300]
                            + "...\nTruncating response to 2000 characters due to discord api limits."
                        )
                    await message.reply(output_text)
    except Exception as e:
        print(f"continue_falcon Error: {e}")
        await message.reply(f"Error: {e} <@811235357663297546> (continue_falcon error)")