Spaces:
Sleeping
Sleeping
import re | |
# Generate a bot reply and record it in the conversation history.
def generate_reply(ctx, makePipeLine):
    """Generate one bot reply, append it to ``ctx``'s history, and return it.

    Args:
        ctx: Conversation context; must provide ``addHistory(role, text)`` plus
            whatever ``generate_valid_response`` reads from it.
        makePipeLine: Model wrapper, forwarded unchanged to
            ``generate_valid_response``.

    Returns:
        The reply text that was added to the history under the ``"bot"`` role.
    """
    response = generate_valid_response(ctx, makePipeLine)
    ctx.addHistory("bot", response)
    # Generating a follow-up continuation for truncated replies is intentionally
    # disabled (the author's note cites problems with incomplete responses).
    return response
# Generate a single bot response.
def generate_valid_response(ctx, makePipeline, max_retries: int = 5) -> str:
    """Generate one bot reply for the conversation held in ``ctx``.

    The prompt is built once from the conversation history. The model is then
    sampled until a response passes ``is_valid_response``, i.e. it does not
    contain a ``"<user_name>:"`` marker (the model writing the user's turn).

    Args:
        ctx: Conversation context; must provide ``getUserName()``,
            ``getBotName()`` and ``getHistory()``.
        makePipeline: Model wrapper exposing ``character_chat(prompt)``, which
            returns the raw generated text.
        max_retries: Maximum number of generation attempts (at least one is
            always made). Bounding the loop means a model that keeps producing
            invalid output cannot trap the caller in an endless retry loop.

    Returns:
        The cleaned response text (see ``clean_response``). If every attempt
        was invalid, the last attempt is cut just before its first
        ``"<user_name>:"`` marker and then cleaned.
    """
    user_name = ctx.getUserName()
    bot_name = ctx.getBotName()

    # Nothing in the loop modifies the history, so build (and log) the prompt once.
    prompt = build_prompt(ctx.getHistory(), user_name, bot_name)
    print("\n==========[DEBUG: Prompt]==========")
    print(prompt)
    print("===================================\n")

    response = ""
    for _ in range(max(1, max_retries)):
        full_text = makePipeline.character_chat(prompt)
        response = extract_response(full_text)
        if is_valid_response(response, user_name, bot_name):
            return clean_response(response, bot_name)

    # All attempts put words in the user's mouth: keep only the text before the
    # first "<user_name>:" marker, which is exactly what is_valid_response rejects.
    response = response.split(user_name + ":", 1)[0]
    return clean_response(response, bot_name)
# Build the model input prompt.
def build_prompt(history, user_name, bot_name,
                 prompt_path="assets/prompt/init.txt", max_turns=16):
    """Assemble the generation prompt from a system prompt file and recent turns.

    Format (role tags of the form ``<|role|>``; note this is NOT ChatML, which
    uses ``<|im_start|>role ... <|im_end|>``. Confirm it matches the target
    model's chat template)::

        <|system|>\\n{system prompt}\\n\\n
        <|user|>\\n{text}\\n\\n  or  <|assistant|>\\n{text}\\n\\n   (per turn)
        <|assistant|>\\n                                        (generation cue)

    Args:
        history: Sequence of turn dicts with ``"role"`` and ``"text"`` keys.
            Role ``"user"`` maps to ``<|user|>``; any other role maps to
            ``<|assistant|>``.
        user_name: Currently unused; kept for interface compatibility.
        bot_name: Currently unused; kept for interface compatibility.
        prompt_path: UTF-8 file holding the system prompt. The path is resolved
            relative to the current working directory, and the file is read on
            every call.
        max_turns: Number of most recent turns to include. Values <= 0 include
            none.

    Returns:
        The prompt string, ending with an open ``<|assistant|>`` turn.
    """
    with open(prompt_path, "r", encoding="utf-8") as f:
        system_prompt = f.read().strip()

    # Guard max_turns <= 0: history[-0:] would return the WHOLE history.
    recent = history[-max_turns:] if max_turns > 0 else []

    parts = [f"<|system|>\n{system_prompt}\n\n"]
    for turn in recent:
        tag = "user" if turn["role"] == "user" else "assistant"
        parts.append(f"<|{tag}|>\n{turn['text']}\n\n")
    # End with an empty assistant turn to cue the model to reply.
    parts.append("<|assistant|>\n")
    return "".join(parts)
# Extract the reply text from raw model output.
def extract_response(full_text):
    """Return the text after the LAST ``"### Response:"`` marker, stripped.

    When the marker is absent, the whole of ``full_text`` is returned stripped.
    The author's note says this follows the HyperCLOVAX output format.
    NOTE(review): the prompt built in this file does not emit this marker;
    confirm the model output really contains it.
    """
    # rpartition yields ('', '', full_text) when the marker is missing, so the
    # tail element covers both the found and not-found cases.
    _, _, tail = full_text.rpartition("### Response:")
    return tail.strip()
# Validate a generated response.
def is_valid_response(text: str, user_name, bot_name) -> bool:
    """Reject responses in which the model writes a line for the user.

    Returns False when ``text`` contains ``user_name`` immediately followed by
    ``":"``, and True otherwise. ``bot_name`` is currently unused.
    """
    user_marker = user_name + ":"
    return user_marker not in text
# Clean up the model output.
def clean_response(text: str, bot_name):
    """Remove bot-name speaker labels, then trim unfinished trailing text.

    Every occurrence of ``"<bot_name>:"`` plus any whitespace after it is
    removed (anywhere in the text, not only at the start). The result is
    stripped and passed through ``clean_truncated_response``.
    """
    # bot_name is literal text, not a pattern: without escaping, names such as
    # "A+" or "(bot" would be read as regex syntax or raise re.error.
    text = re.sub(rf"{re.escape(bot_name)}:\s*", "", text).strip()
    # Drop incomplete sentences.
    return clean_truncated_response(text)
# Remove incomplete sentences.
def clean_truncated_response(text: str) -> str:
    """Drop everything from the first unfinished sentence onward.

    ``text`` is split into segments after '.', '?', '!' or '~' followed by
    whitespace, and at every newline. Segments are kept, in order, until the
    first one that does not end with a terminator: '.', '?', '!', '~',
    U+2026 (ellipsis), U+2639, U+263A, U+2764, U+1F60A or U+1F622, optionally
    followed by the U+FE0F variation selector. That segment and everything
    after it are discarded. Kept segments are joined with single spaces, so
    original newlines become spaces.

    If not even the first segment is complete, the stripped original text is
    returned unchanged.
    """
    # Split after sentence punctuation followed by whitespace, or at any newline.
    sentence_end_pattern = r"(?<=[\.?!~])\s|\n"
    # Emoji are written as full code points (\U0001F60A, \U0001F622): Python
    # str holds code points, so UTF-16 surrogate escapes such as \uD83D\uDE0A
    # can never match real text. \uFE0F? also accepts the emoji form of U+2764.
    complete_end = re.compile(
        r"[.?!~\u2026\u2639\u263A\u2764\U0001F60A\U0001F622]\uFE0F?$"
    )

    stripped = text.strip()
    cleaned = []
    for segment in re.split(sentence_end_pattern, stripped):
        segment = segment.strip()
        if not segment:
            continue
        if not complete_end.search(segment):
            break  # unfinished sentence: discard it and everything after it
        cleaned.append(segment)

    # Fall back to the original text when no complete leading sentence exists.
    result = " ".join(cleaned)
    return result if result else stripped