KaiShin1885 committed on
Commit
221d902
·
verified ·
1 Parent(s): 0272e5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -18
app.py CHANGED
@@ -4,6 +4,8 @@ import os
4
  from huggingface_hub import InferenceClient
5
  import asyncio
6
  import subprocess
 
 
7
 
8
  # ๋กœ๊น… ์„ค์ •
9
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s', handlers=[logging.StreamHandler()])
@@ -16,7 +18,7 @@ intents.guilds = True
16
  intents.guild_messages = True
17
 
18
  # ์ถ”๋ก  API ํด๋ผ์ด์–ธํŠธ ์„ค์ •
19
- hf_client = InferenceClient("CohereForAI/aya-23-8B", token=os.getenv("HF_TOKEN"))
20
 
21
  # ํŠน์ • ์ฑ„๋„ ID
22
  SPECIFIC_CHANNEL_ID = int(os.getenv("DISCORD_CHANNEL_ID"))
@@ -24,6 +26,21 @@ SPECIFIC_CHANNEL_ID = int(os.getenv("DISCORD_CHANNEL_ID"))
24
  # ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ๋ฅผ ์ €์žฅํ•  ์ „์—ญ ๋ณ€์ˆ˜
25
  conversation_history = []
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  class MyClient(discord.Client):
28
  def __init__(self, *args, **kwargs):
29
  super().__init__(*args, **kwargs)
@@ -42,15 +59,9 @@ class MyClient(discord.Client):
42
  if self.is_processing:
43
  return
44
  self.is_processing = True
45
-
46
  try:
47
- if not isinstance(message.channel, discord.Thread):
48
- thread = await message.create_thread(name=f"๋…ผ๋ฌธ ์ž‘์„ฑ - {message.author.display_name}", auto_archive_duration=60)
49
- else:
50
- thread = message.channel
51
-
52
  response = await generate_response(message)
53
- await thread.send(response)
54
  finally:
55
  self.is_processing = False
56
 
@@ -59,14 +70,17 @@ class MyClient(discord.Client):
59
  isinstance(message.channel, discord.Thread) and message.channel.parent_id == SPECIFIC_CHANNEL_ID
60
  )
61
 
62
-
63
  async def generate_response(message):
64
  global conversation_history
65
  user_input = message.content
66
  user_mention = message.author.mention
67
- system_message = f"{user_mention}, Discord์—์„œ ์‚ฌ์šฉ์ž๋“ค์˜ ์งˆ๋ฌธ์— ๋‹ตํ•˜๋Š” ์–ด์‹œ์Šคํ„ดํŠธ์ž…๋‹ˆ๋‹ค."
 
 
 
 
68
  system_prefix = """
69
-
70
  1. ์ฃผ์ œ์— ๋”ฐ๋ฅธ ๋ฌธ๋งฅ ์ดํ•ด์— ๋งž๋Š” ๊ธ€์„ ์จ์ฃผ์„ธ์š”.
71
  2. ์ฃผ์ œ์™€ ์ƒํ™ฉ์— ๋งž๋Š” ์ ์ ˆํ•œ ์–ดํœ˜๋ฅผ ์„ ํƒํ•ด์ฃผ์„ธ์š”
72
  3. ํ•œ๊ตญ ๋ฌธํ™”์™€ ์ ํ•ฉ์„ฑ๋ฅผ ๊ณ ๋ คํ•ด์ฃผ์„ธ์š”
@@ -188,28 +202,48 @@ async def generate_response(message):
188
  """
189
 
190
  conversation_history.append({"role": "user", "content": user_input})
191
- logging.debug(f'Conversation history updated: {conversation_history}')
192
-
193
  messages = [{"role": "system", "content": f"{system_prefix} {system_message}"}] + conversation_history
 
 
 
 
194
  logging.debug(f'Messages to be sent to the model: {messages}')
195
-
196
  loop = asyncio.get_event_loop()
197
  response = await loop.run_in_executor(None, lambda: hf_client.chat_completion(
198
  messages, max_tokens=1000, stream=True, temperature=0.7, top_p=0.85))
199
-
200
  full_response = []
201
  for part in response:
202
  logging.debug(f'Part received from stream: {part}')
203
  if part.choices and part.choices[0].delta and part.choices[0].delta.content:
204
  full_response.append(part.choices[0].delta.content)
205
-
206
  full_response_text = ''.join(full_response)
207
  logging.debug(f'Full model response: {full_response_text}')
208
-
209
  conversation_history.append({"role": "assistant", "content": full_response_text})
210
  return f"{user_mention}, {full_response_text}"
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  if __name__ == "__main__":
213
  discord_client = MyClient(intents=intents)
214
  discord_client.run(os.getenv('DISCORD_TOKEN'))
215
-
 
4
  from huggingface_hub import InferenceClient
5
  import asyncio
6
  import subprocess
7
+ from datasets import load_dataset
8
+ from sentence_transformers import SentenceTransformer, util
9
 
10
  # ๋กœ๊น… ์„ค์ •
11
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s', handlers=[logging.StreamHandler()])
 
18
  intents.guild_messages = True
19
 
20
  # ์ถ”๋ก  API ํด๋ผ์ด์–ธํŠธ ์„ค์ •
21
+ hf_client = InferenceClient("CohereForAI/c4ai-command-r-08-2024", token=os.getenv("HF_TOKEN"))
22
 
23
  # ํŠน์ • ์ฑ„๋„ ID
24
  SPECIFIC_CHANNEL_ID = int(os.getenv("DISCORD_CHANNEL_ID"))
 
26
  # ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ๋ฅผ ์ €์žฅํ•  ์ „์—ญ ๋ณ€์ˆ˜
27
  conversation_history = []
28
 
29
+ # ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ
30
+ datasets = [
31
+ ("all-processed", "all-processed"),
32
+ ("chatdoctor-icliniq", "chatdoctor-icliniq"),
33
+ ("chatdoctor_healthcaremagic", "chatdoctor_healthcaremagic"),
34
+ # ... (๋‚˜๋จธ์ง€ ๋ฐ์ดํ„ฐ์…‹)
35
+ ]
36
+
37
+ all_datasets = {}
38
+ for dataset_name, config in datasets:
39
+ all_datasets[dataset_name] = load_dataset("lavita/medical-qa-datasets", config)
40
+
41
+ # ๋ฌธ์žฅ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋กœ๋“œ
42
+ model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
43
+
44
  class MyClient(discord.Client):
45
  def __init__(self, *args, **kwargs):
46
  super().__init__(*args, **kwargs)
 
59
  if self.is_processing:
60
  return
61
  self.is_processing = True
 
62
  try:
 
 
 
 
 
63
  response = await generate_response(message)
64
+ await message.channel.send(response)
65
  finally:
66
  self.is_processing = False
67
 
 
70
  isinstance(message.channel, discord.Thread) and message.channel.parent_id == SPECIFIC_CHANNEL_ID
71
  )
72
 
 
73
  async def generate_response(message):
74
  global conversation_history
75
  user_input = message.content
76
  user_mention = message.author.mention
77
+
78
+ # ์œ ์‚ฌํ•œ ๋ฐ์ดํ„ฐ ์ฐพ๊ธฐ
79
+ most_similar_data = find_most_similar_data(user_input)
80
+
81
+ system_message = f"{user_mention}, DISCORD์—์„œ ์‚ฌ์šฉ์ž๋“ค์˜ ์งˆ๋ฌธ์— ๋‹ตํ•˜๋Š” ์–ด์‹œ์Šคํ„ดํŠธ์ž…๋‹ˆ๋‹ค."
82
  system_prefix = """
83
+
84
  1. ์ฃผ์ œ์— ๋”ฐ๋ฅธ ๋ฌธ๋งฅ ์ดํ•ด์— ๋งž๋Š” ๊ธ€์„ ์จ์ฃผ์„ธ์š”.
85
  2. ์ฃผ์ œ์™€ ์ƒํ™ฉ์— ๋งž๋Š” ์ ์ ˆํ•œ ์–ดํœ˜๋ฅผ ์„ ํƒํ•ด์ฃผ์„ธ์š”
86
  3. ํ•œ๊ตญ ๋ฌธํ™”์™€ ์ ํ•ฉ์„ฑ๋ฅผ ๊ณ ๋ คํ•ด์ฃผ์„ธ์š”
 
202
  """
203
 
204
  conversation_history.append({"role": "user", "content": user_input})
 
 
205
  messages = [{"role": "system", "content": f"{system_prefix} {system_message}"}] + conversation_history
206
+
207
+ if most_similar_data:
208
+ messages.append({"role": "system", "content": f"๊ด€๋ จ ์ •๋ณด: {most_similar_data}"})
209
+
210
  logging.debug(f'Messages to be sent to the model: {messages}')
211
+
212
  loop = asyncio.get_event_loop()
213
  response = await loop.run_in_executor(None, lambda: hf_client.chat_completion(
214
  messages, max_tokens=1000, stream=True, temperature=0.7, top_p=0.85))
215
+
216
  full_response = []
217
  for part in response:
218
  logging.debug(f'Part received from stream: {part}')
219
  if part.choices and part.choices[0].delta and part.choices[0].delta.content:
220
  full_response.append(part.choices[0].delta.content)
221
+
222
  full_response_text = ''.join(full_response)
223
  logging.debug(f'Full model response: {full_response_text}')
224
+
225
  conversation_history.append({"role": "assistant", "content": full_response_text})
226
  return f"{user_mention}, {full_response_text}"
227
 
228
def find_most_similar_data(query):
    """Return the corpus QA text most similar to *query*, or None.

    Embeds every ``question``/``answer`` item of every loaded dataset split
    and returns the "question: ... answer: ..." text with the highest cosine
    similarity to the query embedding (ties resolve to the first maximum,
    same as the original strictly-greater comparison).

    Performance fix: the original re-encoded the ENTIRE corpus — one
    ``model.encode`` call and one cosine-similarity call per item — on every
    query.  The corpus embeddings are now built once on first use, cached on
    the function object (safe here: ``all_datasets`` is populated once at
    module import and never mutated in the visible code), and compared with
    a single batched cosine-similarity call.
    """
    cache = getattr(find_most_similar_data, "_corpus_cache", None)
    if cache is None:
        texts = []
        for dataset in all_datasets.values():
            for split in dataset.keys():
                for item in dataset[split]:
                    if 'question' in item and 'answer' in item:
                        # Runtime string kept byte-identical to the original.
                        texts.append(f"์งˆ๋ฌธ: {item['question']} ๋‹ต๋ณ€: {item['answer']}")
        if texts:
            # Batch-encode the whole corpus in one call.
            cache = (texts, model.encode(texts, convert_to_tensor=True))
        else:
            cache = ([], None)
        find_most_similar_data._corpus_cache = cache

    texts, corpus_embeddings = cache
    if corpus_embeddings is None:
        # Empty corpus: original returned None (similarity stayed at -1).
        return None

    query_embedding = model.encode(query, convert_to_tensor=True)
    # One batched cosine-similarity instead of one call per item.
    similarities = util.pytorch_cos_sim(query_embedding, corpus_embeddings)[0]
    return texts[int(similarities.argmax())]
246
+
247
if __name__ == "__main__":
    # Fail fast with a clear message when the bot token is missing, instead
    # of the confusing error discord.py raises for a None token.
    token = os.getenv('DISCORD_TOKEN')
    if not token:
        raise RuntimeError("DISCORD_TOKEN environment variable is not set")
    discord_client = MyClient(intents=intents)
    discord_client.run(token)