import os
import gradio as gr
from kiwipiepy import Kiwi

from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_community.document_transformers import LongContextReorder

from libs.config import STREAMING
from libs.embeddings import get_embeddings
from libs.retrievers import load_retrievers
from libs.llm import get_llm
from libs.prompt import get_prompt

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence the HuggingFace tokenizers fork-parallelism warning


def kiwi_tokenize(text):
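    """Split Korean text into token surface forms using the Kiwi morphological analyzer."""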
    kiwi = Kiwi()
    return [token.form for token in kiwi.tokenize(text)]


# Shared embedding model and retriever, created once at import time and reused across requests.
embeddings = get_embeddings()
retriever = load_retrievers(embeddings)


def create_rag_chain(chat_history, model):
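    """Build the RAG pipeline for one request.

    Retrieved documents are reordered with LongContextReorder (to counter the
    "lost in the middle" effect) and passed as context, the chat history is baked
    into the prompt, and `model` selects the configured LLM via `with_config`.
    """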
    llm = get_llm(streaming=STREAMING).with_config(configurable={"llm": model})
    prompt = get_prompt(chat_history)

    return (
        {
            "context": retriever
            | RunnableLambda(LongContextReorder().transform_documents),
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )


def respond_stream(message, history, model):
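    """Stream the answer chunk by chunk from a freshly built RAG chain."""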
    rag_chain = create_rag_chain(history, model)
    for chunk in rag_chain.stream(message):
        yield chunk


def respond(message, history, model):
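    """Return the full answer in a single (non-streaming) call."""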
    rag_chain = create_rag_chain(history, model)
    return rag_chain.invoke(message)


# Available models (key: model identifier, value: label shown to the user)
AVAILABLE_MODELS = {
    "gpt_3_5_turbo": "GPT-3.5 Turbo",
    "gpt_4o": "GPT-4o",
    "claude_3_5_sonnet": "Claude 3.5 Sonnet",
    "gemini_1_5_flash": "Gemini 1.5 Flash",
    "llama3_70b": "Llama3 70b",
}


def get_model_key(label):
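    """Map a display label back to its model key (raises StopIteration for unknown labels)."""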
    return next(key for key, value in AVAILABLE_MODELS.items() if value == label)


def chat_function(message, history, model_label):
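    """Gradio ChatInterface callback: resolve the selected model, then stream or return the answer."""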
    model_key = get_model_key(model_label)
    if STREAMING:
        response = ""
        for chunk in respond_stream(message, history, model_key):
            response += chunk
            yield response
    else:
        response = respond(message, history, model_key)
        yield response


with gr.Blocks() as demo:
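    # UI text is in Korean: the app is a Supreme Court precedent (λŒ€λ²•μ› νŒλ‘€) consultation assistant.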
    gr.Markdown("# λŒ€λ²•μ› νŒλ‘€ 상담 λ„μš°λ―Έ")
    gr.Markdown(
        "μ•ˆλ…•ν•˜μ„Έμš”! λŒ€λ²•μ› νŒλ‘€μ— κ΄€ν•œ μ§ˆλ¬Έμ— λ‹΅λ³€ν•΄λ“œλ¦¬λŠ” AI 상담 λ„μš°λ―Έμž…λ‹ˆλ‹€. νŒλ‘€ 검색, 해석, 적용 등에 λŒ€ν•΄ κΆκΈˆν•˜μ‹  점이 있으면 μ–Έμ œλ“  λ¬Όμ–΄λ³΄μ„Έμš”."
    )

    model_dropdown = gr.Dropdown(
        choices=list(AVAILABLE_MODELS.values()),
        label="λͺ¨λΈ 선택",  # "Select model"
        value=list(AVAILABLE_MODELS.values())[1],  # default to GPT-4o
    )

    chatbot = gr.ChatInterface(
        fn=chat_function,
        additional_inputs=[model_dropdown],
    )

if __name__ == "__main__":
    demo.launch()