lunarflu (HF Staff) committed
Commit e603607 · 1 Parent(s): b7db66e

Synced repo using 'sync_with_huggingface' Github Action

Files changed (2)
  1. app.py +16 -1
  2. codellama.py +139 -0
app.py CHANGED
@@ -9,7 +9,7 @@ from discord import app_commands
 from discord.ext import commands
 from falcon import continue_falcon, try_falcon
 from musicgen import music_create
-
+from codellama import continue_codellama, try_codellama

 # HF GUILD SETTINGS
 MY_GUILD_ID = 1077674588122648679 if os.getenv("TEST_ENV", False) else 879548962464493619
@@ -51,11 +51,26 @@ async def falcon(ctx, prompt: str):
         print(f"Error: {e}")


+@client.hybrid_command(
+    name="codellama",
+    with_app_command=True,
+    description="Enter a prompt to generate code!",
+)
+@app_commands.guilds(MY_GUILD)
+async def codellama(ctx, prompt: str):
+    """Codellama generation"""
+    try:
+        await try_codellama(ctx, prompt)
+    except Exception as e:
+        print(f"Error: (app.py){e}")
+
+
 @client.event
 async def on_message(message):
     """Checks channel and continues Falcon conversation if it's the right Discord Thread"""
     try:
         await continue_falcon(message)
+        await continue_codellama(message)
     except Exception as e:
         print(f"Error: {e}")
 
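The new codellama command mirrors the existing falcon command: hybrid_command exposes it both as a prefixed text command and as a slash command, while app_commands.guilds scopes slash-command registration to the HF guild. Below is a minimal sketch of that registration pattern, assuming discord.py 2.x; the intents, the echo command, and the sync call are illustrative stand-ins, not code from this repo.

import discord
from discord import app_commands
from discord.ext import commands

MY_GUILD = discord.Object(id=879548962464493619)  # assumed: built from MY_GUILD_ID above

intents = discord.Intents.default()
intents.message_content = True  # assumed: on_message handlers need message content
client = commands.Bot(command_prefix="!", intents=intents)

@client.hybrid_command(name="echo", with_app_command=True, description="Repeat a prompt")
@app_commands.guilds(MY_GUILD)
async def echo(ctx: commands.Context, prompt: str):
    # Works when invoked as !echo or as /echo once the command tree is synced,
    # e.g. via `await client.tree.sync(guild=MY_GUILD)` in the bot's setup_hook.
    await ctx.send(prompt)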
codellama.py ADDED
@@ -0,0 +1,139 @@
+import asyncio
+import json
+import os
+
+from gradio_client import Client
+
+HF_TOKEN = os.getenv("HF_TOKEN")
+
+codellama = Client("https://huggingface-projects-codellama-13b-chat.hf.space/", HF_TOKEN)
+
+BOT_USER_ID = 1102236653545861151  # real
+CODELLAMA_CHANNEL_ID = 1147210106321256508  # real
+
+
+codellama_threadid_userid_dictionary = {}
+codellama_threadid_conversation = {}
+
+
+def codellama_initial_generation(prompt, thread):
+    """job.submit inside of run_in_executor = more consistent bot behavior"""
+    global codellama_threadid_conversation
+
+    chat_history = f"{thread.id}.json"
+    conversation = []
+    with open(chat_history, "w") as json_file:
+        json.dump(conversation, json_file)
+
+    job = codellama.submit(prompt, chat_history, fn_index=0)
+
+    while job.done() is False:
+        pass
+    else:
+        result = job.outputs()[-1]
+        with open(result, "r") as json_file:
+            data = json.load(json_file)
+        response = data[-1][-1]
+        conversation.append((prompt, response))
+        with open(chat_history, "w") as json_file:
+            json.dump(conversation, json_file)
+
+    codellama_threadid_conversation[thread.id] = chat_history
+    if len(response) > 1300:
+        response = response[:1300] + "...\nTruncating response due to discord api limits."
+    return response
+
+
+async def try_codellama(ctx, prompt):
+    """Generates text based on a given prompt"""
+    try:
+        global codellama_threadid_userid_dictionary  # tracks userid-thread existence
+        global codellama_threadid_conversation
+
+        if ctx.author.id != BOT_USER_ID:
+            if ctx.channel.id == CODELLAMA_CHANNEL_ID:
+                message = await ctx.send(f"**{prompt}** - {ctx.author.mention}")
+                if len(prompt) > 99:
+                    small_prompt = prompt[:99]
+                else:
+                    small_prompt = prompt
+                thread = await message.create_thread(name=small_prompt, auto_archive_duration=60)
+
+                loop = asyncio.get_running_loop()
+                output_code = await loop.run_in_executor(None, codellama_initial_generation, prompt, thread)
+                codellama_threadid_userid_dictionary[thread.id] = ctx.author.id
+
+                print(output_code)
+                await thread.send(output_code)
+    except Exception as e:
+        print(f"try_codellama Error: {e}")
+        await ctx.send(f"Error: {e} <@811235357663297546> (try_codellama error)")
+
+
+async def continue_codellama(message):
+    """Continues a given conversation based on chat_history"""
+    try:
+        if not message.author.bot:
+            global codellama_threadid_userid_dictionary  # tracks userid-thread existence
+            if message.channel.id in codellama_threadid_userid_dictionary:  # is this a valid thread?
+                if codellama_threadid_userid_dictionary[message.channel.id] == message.author.id:
+                    print("Safetychecks passed for continue_codellama")
+                    global codellama_threadid_conversation
+
+                    prompt = message.content
+                    chat_history = codellama_threadid_conversation[message.channel.id]
+
+                    # Check to see if conversation is ongoing or ended (>15000 characters)
+                    with open(chat_history, "r") as json_file:
+                        conversation = json.load(json_file)
+                    total_characters = 0
+                    for item in conversation:
+                        for string in item:
+                            total_characters += len(string)
+
+                    if total_characters < 15000:
+                        if os.environ.get("TEST_ENV") == "True":
+                            print("Running codellama.submit")
+                        job = codellama.submit(prompt, chat_history, fn_index=0)
+                        while job.done() is False:
+                            pass
+                        else:
+                            if os.environ.get("TEST_ENV") == "True":
+                                print("Continue_codellama job done")
+
+                            result = job.outputs()[-1]
+                            with open(result, "r") as json_file:
+                                data = json.load(json_file)
+                            response = data[-1][-1]
+
+                            with open(chat_history, "r") as json_file:
+                                conversation = json.load(json_file)
+
+                            conversation.append((prompt, response))
+                            # now we have prompt, response, and the newly updated full conversation
+
+                            with open(chat_history, "w") as json_file:
+                                json.dump(conversation, json_file)
+                            if os.environ.get("TEST_ENV") == "True":
+                                print(prompt)
+                                print(response)
+                                print(conversation)
+                                print(chat_history)
+
+                            codellama_threadid_conversation[message.channel.id] = chat_history
+                            if len(response) > 1300:
+                                response = response[:1300] + "...\nTruncating response due to discord api limits."
+
+                            await message.reply(response)
+
+                            total_characters = 0
+                            for item in conversation:
+                                for string in item:
+                                    total_characters += len(string)
+
+                            if total_characters >= 15000:
+                                await message.reply("Conversation ending due to length, feel free to start a new one!")
+
+    except Exception as e:
+        print(f"continue_codellama Error: {e}")
+        await message.reply(f"Error: {e} <@811235357663297546> (continue_codellama error)")
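codellama.py drives the Space through gradio_client's job API: Client.submit returns a Job immediately, job.done() reports completion, and job.outputs() lists the files the Space has produced so far, the last being the updated chat-history JSON the reply is read from. A minimal sketch of that round trip, assuming (as the diff does) that the Space's fn_index=0 endpoint takes a prompt plus a chat-history file path; the prompt and file name below are placeholders.

import json
from gradio_client import Client

client = Client("https://huggingface-projects-codellama-13b-chat.hf.space/")

# Start from an empty conversation, mirroring codellama_initial_generation.
with open("history.json", "w") as f:
    json.dump([], f)

job = client.submit("Write a function that reverses a string", "history.json", fn_index=0)
job.result()  # blocks until done, an alternative to the diff's `while job.done() is False` polling
result_path = job.outputs()[-1]  # path to the final chat-history file

with open(result_path) as f:
    conversation = json.load(f)
print(conversation[-1][-1])  # the model's latest reply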
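The initial generation is dispatched with loop.run_in_executor because the job-polling loop blocks; run directly inside a coroutine it would stall the bot's event loop (and its Discord heartbeat) until the Space finished. A self-contained sketch of that pattern, with blocking_call as a stand-in for the real codellama polling code:

import asyncio
import time

def blocking_call(prompt):
    time.sleep(2)  # stands in for busy-waiting on a gradio job
    return f"response to {prompt!r}"

async def main():
    loop = asyncio.get_running_loop()
    # Hand the blocking function to the default thread-pool executor; the event
    # loop stays free to service other Discord events while it runs.
    result = await loop.run_in_executor(None, blocking_call, "hello")
    print(result)

asyncio.run(main())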