|
import asyncio |
|
import json |
|
import os |
|
|
|
from gradio_client import Client |
|
|
|
# Hugging Face token read from the environment (None when unset); handed to the
# gradio Client below.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Client for the CodeLlama 13B chat Space, created once when this module is imported.
codellama = Client("https://huggingface-projects-codellama-13b-chat.hf.space/", HF_TOKEN)

# Discord IDs.
BOT_USER_ID = 1102236653545861151  # invocations from this user id are ignored by try_codellama
CODELLAMA_CHANNEL_ID = 1147210106321256508  # new conversations are only started in this channel

# Per-conversation state, keyed by Discord thread id:
#   thread id -> id of the user who started that thread's conversation
codellama_threadid_userid_dictionary = {}
#   thread id -> path of the JSON file holding that thread's chat history
codellama_threadid_conversation = {}
|
|
|
|
|
def codellama_initial_generation(prompt, thread):
    """Start a new CodeLlama conversation for ``thread`` and return the first reply.

    This call blocks until the Space finishes generating. Callers should run it
    through ``loop.run_in_executor`` so the Discord event loop stays responsive.

    Creates ``<thread.id>.json`` in the working directory holding the history
    (a list of ``(prompt, response)`` pairs), registers that path in
    ``codellama_threadid_conversation[thread.id]``, and returns the reply. The
    reply is truncated to 1300 characters plus a notice when it is longer.

    Raises whatever ``job.result()`` raises when the remote job fails, plus
    file/JSON errors from reading the result.
    """
    chat_history = f"{thread.id}.json"
    conversation = []
    # Seed an empty history file; its path is passed to the Space as the second input.
    with open(chat_history, "w") as json_file:
        json.dump(conversation, json_file)

    job = codellama.submit(prompt, chat_history, fn_index=0)
    # Block this worker thread until the job completes instead of spinning on
    # job.done(), which pinned a CPU core. result() also re-raises if the job failed.
    job.result()

    result = job.outputs()[-1]
    with open(result, "r") as json_file:
        data = json.load(json_file)
    # NOTE(review): assumes the result file is a JSON list of [user, bot] pairs,
    # so [-1][-1] is the newest bot message -- confirm against the Space's API.
    response = data[-1][-1]

    conversation.append((prompt, response))
    with open(chat_history, "w") as json_file:
        json.dump(conversation, json_file)

    # Mutating the module-level dict needs no ``global`` declaration.
    codellama_threadid_conversation[thread.id] = chat_history

    # The full reply is stored above; only the copy shown on Discord is shortened.
    if len(response) > 1300:
        response = response[:1300] + "...\nTruncating response due to discord api limits."
    return response
|
|
|
|
|
async def try_codellama(ctx, prompt):
    """Open a thread for ``prompt`` and post CodeLlama's first reply into it.

    Does nothing when invoked by BOT_USER_ID or outside CODELLAMA_CHANNEL_ID.
    Records the invoking user as the thread owner after the first reply is
    generated. Any exception is printed and reported back through ``ctx``.
    """
    try:
        # Guard clauses: skip the bot's own user id and any other channel.
        if ctx.author.id == BOT_USER_ID:
            return
        if ctx.channel.id != CODELLAMA_CHANNEL_ID:
            return

        announcement = await ctx.send(f"**{prompt}** - {ctx.author.mention}")
        # Thread name is the first 99 characters of the prompt (presumably to
        # stay within Discord's thread-name length limit).
        thread = await announcement.create_thread(name=prompt[:99], auto_archive_duration=60)

        # The generation call blocks, so run it on the default executor.
        event_loop = asyncio.get_running_loop()
        first_reply = await event_loop.run_in_executor(
            None, codellama_initial_generation, prompt, thread
        )
        # Remember who started the thread; continue_codellama only answers this user.
        codellama_threadid_userid_dictionary[thread.id] = ctx.author.id

        print(first_reply)
        await thread.send(first_reply)
    except Exception as e:
        print(f"try_codellama Error: {e}")
        await ctx.send(f"Error: {e} <@811235357663297546> (try_codellama error)")
|
|
|
|
|
# One asyncio.Lock per thread id. Generation runs off the event loop, so without
# this, two quick messages in the same thread could both read the history file
# before either writes it back, silently dropping a turn. Entries are never
# removed (one small Lock per thread that has been continued).
_codellama_thread_locks = {}


def _codellama_history_length(conversation):
    """Return the total number of characters across all stored (prompt, response) strings."""
    return sum(len(text) for turn in conversation for text in turn)


def _codellama_continue_generation(prompt, chat_history):
    """Run one blocking CodeLlama turn and append it to the history file.

    Submits ``prompt`` together with the history file path to the Space, waits
    for the job to finish, appends ``(prompt, response)`` to the JSON file at
    ``chat_history``, and returns ``(response, updated_conversation)``.
    Must run in a worker thread (e.g. via ``run_in_executor``), never directly
    on the event loop.
    """
    debug = os.environ.get("TEST_ENV") == "True"

    if debug:
        print("Running codellama.submit")
    job = codellama.submit(prompt, chat_history, fn_index=0)
    # Block this worker thread (not the event loop) until the job completes;
    # re-raises if the remote job failed.
    job.result()
    if debug:
        print("Continue_codellama job done")

    result = job.outputs()[-1]
    with open(result, "r") as json_file:
        data = json.load(json_file)
    # NOTE(review): assumes the result file is a JSON list of [user, bot] pairs,
    # so [-1][-1] is the newest bot message -- confirm against the Space's API.
    response = data[-1][-1]

    with open(chat_history, "r") as json_file:
        conversation = json.load(json_file)
    conversation.append((prompt, response))
    with open(chat_history, "w") as json_file:
        json.dump(conversation, json_file)

    if debug:
        print(prompt)
        print(response)
        print(conversation)
        print(chat_history)
    return response, conversation


async def continue_codellama(message):
    """Continue a CodeLlama thread conversation with ``message``.

    Acts only on non-bot messages sent in a thread registered in
    ``codellama_threadid_userid_dictionary`` by the user who started it. No new
    turn is generated once the stored history reaches 15000 characters; the
    user is told when a turn pushes it over that limit. The blocking Space call
    runs on the default executor so the Discord event loop keeps running.
    Any exception is printed and replied to ``message``.
    """
    try:
        if message.author.bot:
            return
        thread_id = message.channel.id
        # Missing thread -> None, which never equals a real author id.
        if codellama_threadid_userid_dictionary.get(thread_id) != message.author.id:
            return
        print("Safetychecks passed for continue_codellama")

        lock = _codellama_thread_locks.get(thread_id)
        if lock is None:
            lock = _codellama_thread_locks[thread_id] = asyncio.Lock()

        async with lock:
            prompt = message.content
            chat_history = codellama_threadid_conversation[thread_id]

            with open(chat_history, "r") as json_file:
                conversation = json.load(json_file)
            if _codellama_history_length(conversation) >= 15000:
                # Already at the cap; the user was notified when it was reached.
                return

            loop = asyncio.get_running_loop()
            response, conversation = await loop.run_in_executor(
                None, _codellama_continue_generation, prompt, chat_history
            )

            # The full response is already persisted; only the Discord copy is shortened.
            if len(response) > 1300:
                response = response[:1300] + "...\nTruncating response due to discord api limits."
            await message.reply(response)

            if _codellama_history_length(conversation) >= 15000:
                await message.reply("Conversation ending due to length, feel free to start a new one!")
    except Exception as e:
        print(f"continue_codellama Error: {e}")
        await message.reply(f"Error: {e} <@811235357663297546> (continue_codellama error)")
|
|