import asyncio
import os
from typing import Optional

import gradio_client as grc
from gradio_client.utils import QueueError


HF_TOKEN = os.getenv("HF_TOKEN")


# Discord IDs for the bot user and the Falcon channel; use the test-server IDs when TEST_ENV is set.
BOT_USER_ID = 1086256910572986469 if os.getenv("TEST_ENV", False) else 1102236653545861151
FALCON_CHANNEL_ID = 1079459939405279232 if os.getenv("TEST_ENV", False) else 1119313248056004729


# Per-thread state: map Discord thread IDs to the gradio client that holds the
# conversation session, and to the user who started the conversation.
thread_to_client = {}
thread_to_user = {}


async def wait(job):
    """Poll a gradio job until it finishes without blocking the event loop."""
    while not job.done():
        await asyncio.sleep(0.2)


def get_client(session: Optional[str] = None) -> grc.Client:
    """Create a gradio client for the falcon-180b Space, optionally reusing an existing session hash."""
    client = grc.Client("huggingface-projects/falcon-180b-discord", hf_token=HF_TOKEN)
    if session:
        client.session_hash = session
    return client


async def falcon_chat(ctx, prompt):
    """Generates text based on a given prompt"""
    try:
        if ctx.author.id != BOT_USER_ID:
            if ctx.channel.id == FALCON_CHANNEL_ID:
                if os.environ.get("TEST_ENV") == "True":
                    print("Safetychecks passed for try_falcon")
                message = await ctx.send(f"**{prompt}** - {ctx.author.mention}")
                # Discord thread names are limited to 100 characters, so truncate long prompts.
                small_prompt = prompt[:99]
                thread = await message.create_thread(name=small_prompt, auto_archive_duration=60)

                if os.environ.get("TEST_ENV") == "True":
                    print("Running falcon_initial_generation...")
                # Constructing the gradio client blocks, so build it in the default executor.
                loop = asyncio.get_running_loop()
                client = await loop.run_in_executor(None, get_client, None)
                job = client.submit(prompt, api_name="/chat")
                await wait(job)
                try:
                    job.result()  # raises if the remote job failed
                    response = job.outputs()[-1]  # the last streamed output is the complete reply
                    await thread.send(response)
                    thread_to_client[thread.id] = client
                    thread_to_user[thread.id] = ctx.author.id
                except QueueError:
                    await thread.send("The gradio space powering this bot is really busy! Please try again later!")
    except Exception as e:
        print(f"chat (180B) Error: {e}")


async def continue_chat(message):
    """Continues a given conversation based on chathistory"""
    try:
        if message.channel.id in thread_to_user:
            # Only the user who started the thread may continue the conversation.
            if thread_to_user[message.channel.id] == message.author.id:
                client = thread_to_client[message.channel.id]
                job = client.submit(message.content, api_name="/chat")
                await wait(job)
                try:
                    job.result()
                    response = job.outputs()[-1]
                    await message.reply(response)
                except QueueError:
                    await message.reply("The gradio space powering this bot is really busy! Please try again later!")
    except Exception as e:
        print(f"continue_falcon Error: {e}")
        await message.reply(f"Error: {e} <@811235357663297546> (continue_falcon error)")