File size: 2,407 Bytes
aab927d
 
 
f8c71d2
aab927d
 
 
 
73dadd8
 
aab927d
 
 
 
73dadd8
aab927d
 
f8c71d2
aab927d
 
 
 
 
 
73dadd8
151bab0
aab927d
 
 
 
 
 
 
73dadd8
aab927d
73dadd8
 
aab927d
 
73dadd8
 
 
 
 
 
 
aab927d
73dadd8
 
 
aab927d
 
73dadd8
 
 
 
 
aab927d
73dadd8
aab927d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import re

# ์ƒ์„ฑ๋œ ๋ชจ๋“  ๋ด‡ ์‘๋‹ต ๊ธฐ๋ก
def generate_reply(ctx, makePipeLine, user_msg):
    # ์ตœ์ดˆ ์‘๋‹ต
    response = generate_valid_response(ctx, makePipeLine, user_msg)
    ctx.addHistory("bot", response)

    # ๋ถˆ์•ˆ์ •ํ•œ ์‘๋‹ต์ด ์œ ๋„๋˜๋ฏ€๋กœ ์‚ฌ์šฉํ•˜์ง€ ์•Š์Œ
    '''
    # ์‘๋‹ต์ด ๋Š๊ฒผ๋‹ค๋ฉด ์ถ”๊ฐ€ ์ƒ์„ฑ
    if is_truncated_response(response):
        continuation = generate_valid_response(ctx, makePipeLine, response)
        ctx.addHistory("bot", continuation)
    '''

# ๋ด‡ ์‘๋‹ต 1ํšŒ ์ƒ์„ฑ
def generate_valid_response(ctx, makePipeline, user_msg) -> str:
    user_name = ctx.getUserName()
    bot_name = ctx.getBotName()

    while True:
        prompt = build_prompt(ctx.getHistory(), user_msg, user_name, bot_name)
        full_text = makePipeline.character_chat(prompt)
        response = extract_response(full_text)
        print(f"debug: {response}")
        if is_valid_response(response, user_name, bot_name):
            break
    return clean_response(response, bot_name)

# Build the full model prompt from the system prompt, recent history, and the
# latest user message.
def build_prompt(history, user_msg, user_name, bot_name):
    """Assemble an Instruction/Response-format prompt string.

    Parameters:
        history: list of turn dicts with keys "role" ("user" or bot) and "text".
        user_msg: the user's latest message, appended after the history.
        user_name: display name used for user turns.
        bot_name: display name used for bot turns; the prompt ends with
            "{bot_name}:" so the model continues speaking as the bot.

    Returns:
        The complete prompt string.
    """
    with open("assets/prompt/init.txt", "r", encoding="utf-8") as f:
        system_prompt = f.read().strip()

    # Rebuild the recent dialogue (last 16 turns) as plain "Name: text" lines.
    # join() instead of += in a loop avoids quadratic string building.
    lines = [
        f"{user_name if turn['role'] == 'user' else bot_name}: {turn['text']}"
        for turn in history[-16:]
    ]
    lines.append(f"{user_name}: {user_msg}")
    dialogue = "\n".join(lines) + "\n"

    # Template format expected by the model.
    return f"""### Instruction:
{system_prompt}

{dialogue}
### Response:
{bot_name}:"""

# ์ถœ๋ ฅ์—์„œ ์‘๋‹ต ์ถ”์ถœ (HyperCLOVAX ํฌ๋งท์— ๋งž๊ฒŒ)
def extract_response(full_text):
    # '### Response:' ์ดํ›„ ํ…์ŠคํŠธ ์ถ”์ถœ
    if "### Response:" in full_text:
        reply = full_text.split("### Response:")[-1].strip()
    else:
        reply = full_text.strip()
    return reply

# Validate a generated response.
def is_valid_response(text: str, user_name, bot_name) -> bool:
    """Return True when the response is usable.

    A response is rejected when it is empty or whitespace-only (previously
    such responses passed validation and were stored in the history), or
    when it hallucinates the user's side of the dialogue by containing
    "{user_name}:". bot_name is kept for signature symmetry with
    clean_response; it is not used in any check here.
    """
    if not text.strip():
        return False
    if f"{user_name}:" in text:
        return False
    return True

# Strip "{bot_name}:" speaker labels from the response text.
def clean_response(text: str, bot_name):
    """Remove every "{bot_name}:" label (and following whitespace), then strip.

    Bug fix: the original pattern rf"{bot_name}:\\s*" contained a doubled
    backslash inside an f-string, so the regex required a literal backslash
    after the colon and never matched anything; \s* now matches trailing
    whitespace as intended. bot_name is re.escape()d so names containing
    regex metacharacters cannot break or broaden the pattern.
    """
    return re.sub(rf"{re.escape(bot_name)}:\s*", "", text).strip()

# ์ค‘๋‹จ๋œ ์‘๋‹ต ์—ฌ๋ถ€ ๊ฒ€์‚ฌ
def is_truncated_response(text: str) -> bool:
    return re.search(r"[.?!โ€ฆ\u2026\u2639\u263A\u2764\uD83D\uDE0A\uD83D\uDE22]$", text.strip()) is None