File size: 5,920 Bytes
e603607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import asyncio
import json
import os
import time

from gradio_client import Client

# Hugging Face access token (may be None if unset) used to authenticate the Space client.
HF_TOKEN = os.getenv("HF_TOKEN")

# Gradio client for the CodeLlama-13B chat Space; constructed once at import time.
codellama = Client("https://huggingface-projects-codellama-13b-chat.hf.space/", HF_TOKEN)

BOT_USER_ID = 1102236653545861151  # real
CODELLAMA_CHANNEL_ID = 1147210106321256508  # real


# thread.id -> discord user id allowed to continue that thread's conversation
codellama_threadid_userid_dictionary = {}
# thread.id -> path of the JSON file holding that thread's chat history
codellama_threadid_conversation = {}


def codellama_initial_generation(prompt, thread):
    """Run the first generation for a new thread and return the response text.

    Blocking by design: it is executed via ``loop.run_in_executor`` so the
    Discord event loop stays responsive (see ``try_codellama``).

    Creates a fresh ``<thread.id>.json`` history file, submits the prompt to
    the CodeLlama Space, appends the (prompt, response) pair to the history,
    and registers the history file in ``codellama_threadid_conversation``.
    """
    global codellama_threadid_conversation

    # Start every thread with an empty on-disk conversation history.
    chat_history = f"{thread.id}.json"
    conversation = []
    with open(chat_history, "w") as json_file:
        json.dump(conversation, json_file)

    job = codellama.submit(prompt, chat_history, fn_index=0)

    # Poll with a short sleep instead of the previous `while not done: pass`
    # busy-wait, which pegged a CPU core for the entire generation.
    while not job.done():
        time.sleep(0.2)

    # The last output is a JSON file containing the full history; the final
    # [-1][-1] entry is the model's newest reply.
    result = job.outputs()[-1]
    with open(result, "r") as json_file:
        data = json.load(json_file)
    response = data[-1][-1]
    conversation.append((prompt, response))
    with open(chat_history, "w") as json_file:
        json.dump(conversation, json_file)

    codellama_threadid_conversation[thread.id] = chat_history
    # Discord messages are capped at 2000 characters; stay well under it.
    if len(response) > 1300:
        response = response[:1300] + "...\nTruncating response due to discord api limits."
    return response


async def try_codellama(ctx, prompt):
    """Generates text based on a given prompt.

    Echoes the prompt in the CodeLlama channel, opens a thread named after it,
    runs the blocking initial generation in an executor (keeping the event
    loop responsive), and posts the result in the thread.
    """
    try:
        global codellama_threadid_userid_dictionary  # tracks userid-thread existence
        global codellama_threadid_conversation

        # Ignore the bot's own messages and anything outside the dedicated channel.
        if ctx.author.id == BOT_USER_ID or ctx.channel.id != CODELLAMA_CHANNEL_ID:
            return

        message = await ctx.send(f"**{prompt}** - {ctx.author.mention}")
        # Thread names have a length limit; slicing is a no-op for short prompts,
        # so the previous len() check was redundant.
        small_prompt = prompt[:99]
        thread = await message.create_thread(name=small_prompt, auto_archive_duration=60)

        # codellama_initial_generation blocks, so run it off the event loop.
        loop = asyncio.get_running_loop()
        output_code = await loop.run_in_executor(None, codellama_initial_generation, prompt, thread)
        codellama_threadid_userid_dictionary[thread.id] = ctx.author.id

        print(output_code)
        await thread.send(output_code)
    except Exception as e:
        print(f"try_codellama Error: {e}")
        await ctx.send(f"Error: {e} <@811235357663297546> (try_codellama error)")


def _conversation_length(conversation):
    """Return the total number of characters across all (prompt, response) pairs."""
    return sum(len(text) for pair in conversation for text in pair)


async def continue_codellama(message):
    """Continues a given conversation based on chat_history.

    Only responds inside a thread this bot created, and only to the user who
    started it. Conversations are capped at 15000 total characters.
    """
    try:
        global codellama_threadid_userid_dictionary  # tracks userid-thread existence
        global codellama_threadid_conversation

        # Safetychecks: not a bot, a thread we own, and the thread's owner.
        if message.author.bot:
            return
        if message.channel.id not in codellama_threadid_userid_dictionary:
            return
        if codellama_threadid_userid_dictionary[message.channel.id] != message.author.id:
            return
        print("Safetychecks passed for continue_codellama")

        prompt = message.content
        chat_history = codellama_threadid_conversation[message.channel.id]

        # Check to see if conversation is ongoing or ended (>15000 characters)
        with open(chat_history, "r") as json_file:
            conversation = json.load(json_file)
        if _conversation_length(conversation) >= 15000:
            return

        if os.environ.get("TEST_ENV") == "True":
            print("Running codellama.submit")
        job = codellama.submit(prompt, chat_history, fn_index=0)
        # Await completion with asyncio.sleep: the previous
        # `while job.done() is False: pass` busy-wait blocked the entire
        # event loop (freezing the bot) for the whole generation.
        while not job.done():
            await asyncio.sleep(0.2)
        if os.environ.get("TEST_ENV") == "True":
            print("Continue_codellama job done")

        # The last output is a JSON file with the full history; the final
        # [-1][-1] entry is the model's newest reply.
        result = job.outputs()[-1]
        with open(result, "r") as json_file:
            data = json.load(json_file)
        response = data[-1][-1]

        # `conversation` was loaded above and the file hasn't changed since,
        # so no need to re-read it before appending.
        conversation.append((prompt, response))
        with open(chat_history, "w") as json_file:
            json.dump(conversation, json_file)
        if os.environ.get("TEST_ENV") == "True":
            print(prompt)
            print(response)
            print(conversation)
            print(chat_history)

        codellama_threadid_conversation[message.channel.id] = chat_history
        # Discord messages are capped at 2000 characters; stay well under it.
        if len(response) > 1300:
            response = response[:1300] + "...\nTruncating response due to discord api limits."

        await message.reply(response)

        if _conversation_length(conversation) >= 15000:
            await message.reply("Conversation ending due to length, feel free to start a new one!")

    except Exception as e:
        print(f"continue_codellama Error: {e}")
        await message.reply(f"Error: {e} <@811235357663297546> (continue_codellama error)")