import os import gradio as gr from kiwipiepy import Kiwi from typing import List, Tuple, Generator, Union from langchain_core.output_parsers import StrOutputParser from langchain_core.messages import HumanMessage, AIMessage from langchain_core.runnables import RunnablePassthrough, RunnableLambda from langchain_community.document_transformers import LongContextReorder from libs.config import STREAMING from libs.embeddings import get_embeddings from libs.retrievers import load_retrievers from libs.llm import get_llm from libs.prompt import get_prompt os.environ["TOKENIZERS_PARALLELISM"] = "false" def kiwi_tokenize(text): kiwi = Kiwi() return [token.form for token in kiwi.tokenize(text)] embeddings = get_embeddings() retriever = load_retrievers(embeddings) # 사용 가능한 모델 목록 (key: 모델 식별자, value: 사용자에게 표시할 레이블) AVAILABLE_MODELS = { # "gpt_3_5_turbo": "GPT-3.5 Turbo", "gpt_4o": "GPT-4o", "gemini_1_5_flash": "Gemini 1.5 Flash", "claude_3_5_sonnet": "Claude 3.5 Sonnet", # "llama3_70b": "Llama3 70b", } def create_rag_chain(chat_history: List[Tuple[str, str]], model: str): langchain_messages = [] for human, ai in chat_history: langchain_messages.append(HumanMessage(content=human)) langchain_messages.append(AIMessage(content=ai)) llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model}) prompt = get_prompt().partial(history=langchain_messages) return ( { "context": retriever | RunnableLambda(LongContextReorder().transform_documents), "question": RunnablePassthrough(), } | prompt | llm | StrOutputParser() ) def get_model_key(label): return next(key for key, value in AVAILABLE_MODELS.items() if value == label) def respond_stream( message: str, history: List[Tuple[str, str]], model: str ) -> Generator[str, None, None]: rag_chain = create_rag_chain(history, model) for chunk in rag_chain.stream(message): yield chunk def respond(message: str, history: List[Tuple[str, str]], model: str) -> str: rag_chain = create_rag_chain(history, model) return rag_chain.invoke(message) def get_model_key(label: str) -> str: return next(key for key, value in AVAILABLE_MODELS.items() if value == label) def validate_input(message: str) -> bool: """입력된 메시지가 유효한지 검사합니다.""" return bool(message.strip()) def chat_function( message: str, history: List[Tuple[str, str]], model_label: str ) -> Generator[str, None, None]: if not validate_input(message): yield "메시지를 입력해주세요." return model_key = get_model_key(model_label) if STREAMING: response = "" for chunk in respond_stream(message, history, model_key): response += chunk yield response else: response = respond(message, history, model_key) yield response with gr.Blocks( fill_height=True, ) as demo: gr.Markdown("# 대법원 판례 상담 도우미") gr.Markdown( "안녕하세요! 대법원 판례에 관한 질문에 답변해드리는 AI 상담 도우미입니다. 판례 검색, 해석, 적용 등에 대해 궁금하신 점이 있으면 언제든 물어보세요." ) model_dropdown = gr.Dropdown( choices=list(AVAILABLE_MODELS.values()), label="모델 선택", value=list(AVAILABLE_MODELS.values())[0], ) chatbot = gr.ChatInterface( fn=chat_function, autofocus=True, fill_height=True, multimodal=False, examples=[ [ "중고차 거래를 했는데 불량으로 차 수리에 500만원이 들었습니다. 판매자에게 법적 책임을 물을 수 있나요? 비슷한 판례를 소개해주세요.", "GPT-4o", ], [ "약 2천 평의 농지를 구매했는데, 알고 보니 주택을 지을 수 없는 땅이었습니다. 이와 유사한 부동산 사기 관련 판례를 알려주세요.", "GPT-4o", ], [ "지인이 장난으로 휘두른 칼에 팔이 20cm 가량 찔렸습니다. 장난이라고 주장하는데, 이와 유사한 상해 관련 판례를 알려주세요.", "GPT-4o", ], ], additional_inputs=[model_dropdown], ) if __name__ == "__main__": demo.launch()