| | import os
|
| | import csv
|
| | import time
|
| | import torch
|
| | import argparse
|
| | import chromadb
|
| | import datetime
|
| | import gradio as gr
|
| |
|
| | from groq import Groq
|
| | from pathlib import Path
|
| | from prompt_db import *
|
| | from chromadb.utils import embedding_functions
|
| |
|
| |
|
def get_chroma_collection(db_path: str, collection_name: str, *, embedf_name: str = "") -> chromadb.Collection | None:
    """Load a ChromaDB collection from a persistent client.

    Args:
        db_path: Absolute path to the directory holding the chromadb collection.
        collection_name: Name of the chromadb collection.
        embedf_name: Optional SentenceTransformer model name for embeddings;
            when empty, chromadb's default embedding function is used.

    Returns:
        The chromadb collection object, or None when the path or the
        collection cannot be found.
    """
    if not os.path.exists(db_path):
        print(f"collection {collection_name} 을(를) 찾을 수 없습니다. 경로를 다시 확인해주세요.")
        return None

    client = chromadb.PersistentClient(path=db_path)

    if embedf_name:
        # SentenceTransformer-backed embeddings, on GPU when available.
        embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=embedf_name,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        print(f"임베딩 함수로 {embedf_name} 를 사용합니다. ")
    else:
        embedding_fn = embedding_functions.DefaultEmbeddingFunction()
        print("임베딩 함수로 기본 임베딩 함수를 사용합니다. ")

    try:
        loaded = client.get_collection(
            name=collection_name,
            embedding_function=embedding_fn,
        )
    except Exception as e:
        print(f"Collection '{collection_name}' 을(를) 불러오지 못했습니다 : {e}")
        return None

    print(f"Collection '{collection_name}' 을(를) 성공적으로 불러왔습니다. ")
    return loaded
|
def query_db(collection: chromadb.Collection,
             query_text: str,
             n_results: int) -> str:
    """Search the DB (collection) for documents related to the user question.

    Args:
        collection: chromadb collection to search (may be None).
        query_text: The user's question.
        n_results: Number of documents to retrieve.

    Returns:
        A formatted string of the retrieved documents, or "" when nothing
        relevant is found or the database is unavailable.
    """
    if collection is None:
        print("데이터베이스가 연결되지 않았습니다.")
        return ""

    try:
        hits = collection.query(query_texts=[query_text], n_results=n_results)

        docs = hits["documents"]
        if not docs or not docs[0]:
            print("관련된 문서를 찾을 수 없습니다.")
            return ""

        metas = hits["metadatas"][0]
        # One formatted entry per retrieved document, joined by blank lines.
        formatted = [
            f"문서{idx + 1} [제목: {meta.get('title', '제목 없음')}, 날짜: {meta.get('date', '날짜 없음')}]\n내용 : {text}"
            for idx, (text, meta) in enumerate(zip(docs[0], metas))
        ]
        return "\n\n".join(formatted)

    except Exception as e:
        print(f"검색 중 오류 발생: {e}")
        return ""
|
def save_log(base_dir, log_dir, request, user_message, assistant_message):
    """Append one chat turn to a per-day CSV log file.

    Args:
        base_dir: Base directory of the application.
        log_dir: Sub-directory (under base_dir) where log files are kept.
        request: gr.Request carrying the caller's IP (may be None).
        user_message: The user's message for this turn.
        assistant_message: The assistant's reply for this turn.
    """
    log_path = os.path.join(base_dir, log_dir)

    # makedirs(exist_ok=True) creates missing parent directories and is
    # race-safe, unlike os.mkdir which raises when base_dir is absent or
    # when the folder is created concurrently by another request.
    if not os.path.exists(log_path):
        os.makedirs(log_path, exist_ok=True)
        print(f"{log_dir} 폴더가 생성되었습니다 : {log_path}")

    today = datetime.datetime.now().strftime("%y%m%d")
    file_name = f"chat_log_{today}.csv"
    dest_file_path = os.path.join(log_path, file_name)

    user_ip = request.client.host if request else "Unknown_IP"
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    user_conv_log = [user_ip, timestamp, user_message, assistant_message]

    try:
        # Write the header once, then append rows. Both writes live inside
        # the try so a failed header write cannot crash the chat handler
        # (previously the header write was unguarded).
        write_header = not os.path.exists(dest_file_path)
        with open(dest_file_path, mode="a", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            if write_header:
                writer.writerow(["user_ip", "time_stamp", "user_message", "assistant_message"])
            writer.writerow(user_conv_log)
    except Exception as e:
        print(f"대화 로그 저장 실패 : {e}")
|
def get_response(user_message: str,
                 system_prompt: str,
                 collection: chromadb.Collection,
                 history: list[dict | list],
                 request: gr.Request,
                 client: Groq,
                 base_dir: str,
                 log_dir: str,
                 model_name: str,
                 n_results: int,
                 temperature: float):
    """Stream an LLM answer grounded in documents retrieved from the DB.

    Yields the accumulated assistant message chunk by chunk, then logs the
    finished turn via save_log. Typing "끝내" ends the conversation without
    calling the LLM (and without logging).
    """
    # Sentinel message lets the user end the session explicitly.
    if user_message.strip() == "끝내":
        yield "대화를 종료합니다. 새 대화를 시작하려면 오른쪽 상단의 Clear 버튼(휴지통 아이콘)을 클릭해주세요."
        return

    context = query_db(collection=collection,
                       query_text=user_message,
                       n_results=n_results)

    # Inject the retrieved context into the system prompt template.
    rag_prompt = system_prompt.format(context=context)

    messages = [{"role": "system", "content": rag_prompt}]
    for turn in history:
        # Gradio "messages" format: one dict per role.
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
        # Legacy "tuples" format: [user_text, assistant_text] pairs.
        elif isinstance(turn, list) and len(turn) == 2:
            messages.append({"role": "user", "content": turn[0]})
            messages.append({"role": "assistant", "content": turn[1]})
    messages.append({"role": "user", "content": user_message})

    assistant_message = ""
    try:
        stream = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=temperature,
            stream=True,
        )
        # Yield the growing message so the UI renders a live stream.
        for chunk in stream:
            piece = chunk.choices[0].delta.content
            if piece:
                assistant_message += piece
                yield assistant_message
    except Exception as e:
        assistant_message = f"답변 생성 중 오류가 발생했습니다. : {str(e)}"
        yield assistant_message

    save_log(base_dir, log_dir, request, user_message, assistant_message)
|
def chat_with_rag(api_key: str,
                  collection: chromadb.Collection,
                  system_prompt: str,
                  args: argparse.ArgumentParser) -> None:
    """Launch the Gradio RAG chatbot UI.

    Args:
        api_key: Groq API key.
        collection: chromadb collection used for retrieval.
        system_prompt: System prompt template containing a {context} slot.
        args: Parsed CLI arguments (model name, log dirs, sampling options).

    Returns:
        None. Blocks while the Gradio app is running.
    """
    try:
        groq_client = Groq(api_key=api_key)
    except Exception as e:
        print(f"Groq client를 불러오지 못했습니다. API Key를 확인해주세요 : {e}")
        # BUG FIX: previously execution fell through with `groq_client`
        # undefined, crashing with NameError on the first chat turn.
        return

    def predict(user_message, history, request: gr.Request):
        # Gradio streams whatever get_response yields.
        yield from get_response(
            user_message=user_message,
            system_prompt=system_prompt,
            collection=collection,
            history=history,
            request=request,
            client=groq_client,
            base_dir=args.base_dir,
            log_dir=args.log_dir,
            model_name=args.model_name,
            n_results=args.n_results,
            temperature=args.temperature
        )

    title = "ChaTech"
    description = """
    서울과학기술대학교 공지사항 기반 질의응답 챗봇입니다.
    데이터베이스에 저장된 공지사항 내용을 바탕으로 답변합니다.
    대화 종료를 원하실 경우 채팅창에 \'끝내\'을 입력해주세요.
    """

    demo = gr.ChatInterface(
        fn=predict,
        title=title,
        description=description
    ).queue()

    demo.launch(debug=True, share=True)
|
def get_system_prompt(prompt_type: str) -> str:
    """Load a system prompt from prompt_db.py and return it.

    Args:
        prompt_type: Which system prompt to use.
            "v"    : vanilla prompt
            "adv1" : advanced prompt ver.1 (not implemented yet)

    Returns:
        The full system prompt text.
    """
    if prompt_type == "v":
        return Vanilla().get_prompt()
    elif prompt_type == "adv1":
        # Not implemented yet; returns an empty prompt.
        return ""
    else:
        # BUG FIX: the fallback previously referenced `vanilla` without
        # ever creating it, raising NameError for unknown prompt types.
        print("유효하지 않은 프롬프트 타입입니다. 기본값(Vanilla)을 사용합니다. ")
        return Vanilla().get_prompt()
|
def main(args):
    """Wire up the DB collection, system prompt, and chatbot UI.

    Args:
        args: Parsed CLI arguments (paths, collection/model names, options).
    """
    abs_db_path = os.path.join(args.base_dir, args.db_dir)

    # BUG FIX: --embedf_name was parsed but never forwarded, so the
    # collection always fell back to the default embedding function
    # regardless of the flag.
    collection = get_chroma_collection(abs_db_path, args.collection_name,
                                       embedf_name=args.embedf_name)

    if collection is None:
        print("Chromadb Collection을 불러오지 못했습니다. 프로그램을 종료합니다. ")
        return

    system_prompt = get_system_prompt(args.prompt_type)

    chat_with_rag(api_key=args.api_key,
                  collection=collection,
                  system_prompt=system_prompt,
                  args=args)
|
| | if __name__ == "__main__":
|
| | parser = argparse.ArgumentParser()
|
| | parser.add_argument("--api_key", type = str, default = "")
|
| | parser.add_argument("--base_dir", type = str, default = str(Path(__file__).resolve().parent))
|
| | parser.add_argument("--db_dir", type = str, default = "seoultech_data_db")
|
| | parser.add_argument("--log_dir", type = str, default = "chat_log")
|
| | parser.add_argument("--model_name", type = str, default = "llama-3.3-70b-versatile")
|
| | parser.add_argument("--temperature", type = float, default = 0.5)
|
| | parser.add_argument("--n_results", type = int, default = 3)
|
| | parser.add_argument("--collection_name", type = str, default = "seoultech_notices")
|
| | parser.add_argument("--embedf_name", type = str, default = "BAAI/bge-m3")
|
| | parser.add_argument("--prompt_type", type = str, default = "v")
|
| |
|
| | args = parser.parse_args()
|
| | main(args)
|
| |
|
| |
|
| |
|
| |
|