# Spaces:
# Sleeping
# Sleeping
import html
import re
from pathlib import Path

import gradio as gr

from core.context_manager import ContextManager
from core.make_pipeline import MakePipeline
from core.make_reply import generate_reply
from core.utils import load_config as load_full_config, save_config as save_full_config, load_llm_config
# Splits a message into "*action*" spans (group 1) and runs of plain text
# (group 2). A lone, unpaired "*" matches neither alternative, so findall()
# skips over it and it does not appear in the output.
_EMOTION_PATTERN = re.compile(r"\*(.+?)\*|([^\*]+)")

# CSS class used for each history role; entries with other roles are not drawn.
_BUBBLE_CLASS_BY_ROLE = {"user": "bubble-right", "bot": "bubble-left"}


def _parse_emotion_text(text: str) -> str:
    """Convert one chat message into HTML fragments.

    ``*...*`` spans become a gray ``<div>`` (asterisks kept); every non-blank
    line of the remaining text becomes its own ``<div>``. All message text is
    HTML-escaped, so user or model output cannot inject markup into the page.
    """
    segments = []
    for action, plain in _EMOTION_PATTERN.findall(text):
        if action:
            segments.append(f"<div style='color:gray'>*{html.escape(action)}*</div>")
        elif plain:
            for line in plain.strip().splitlines():
                line = line.strip()
                if line:
                    segments.append(f"<div>{html.escape(line)}</div>")
    return "\n".join(segments)


def _render_history_html(history) -> str:
    """Render chat history entries as chat-bubble HTML.

    Each entry is read as a mapping with "role" and "text" keys. Role "user"
    is drawn as a right-hand bubble, role "bot" as a left-hand bubble, and any
    other role is skipped.
    """
    bubbles = []
    for item in history:
        css_class = _BUBBLE_CLASS_BY_ROLE.get(item["role"])
        if css_class is None:
            continue
        bubbles.append(f"<div class='{css_class}'>{_parse_emotion_text(item['text'])}</div>")
    return "".join(bubbles)


def create_interface(
    ctx: ContextManager,
    makePipeline: MakePipeline,
    *,
    config_path: str = "config.json",
    prompt_path: str = "assets/prompt/init.txt",
) -> gr.Blocks:
    """Build (but do not launch) the Gradio UI for chatting with the bot.

    The returned ``gr.Blocks`` app has three tabs:

    1. Chat: renders the conversation held in ``ctx``. On submit it appends
       the user message, calls ``generate_reply(ctx, makePipeline)`` and
       re-renders the history.
    2. Model settings: sliders for temperature, top-p, repetition penalty and
       max new tokens. They can be applied to ``makePipeline`` and loaded
       from / saved to the ``"llm"`` block of ``config_path``.
    3. Prompt settings: a text editor for the file at ``prompt_path``.

    Args:
        ctx: Conversation context. It is the initial value of a ``gr.State``,
            which deep-copies it, so each browser session gets its own copy.
        makePipeline: Object whose ``update_config`` receives the slider
            settings; also passed to ``generate_reply``.
        config_path: Config file read/written by the settings tab.
        prompt_path: Prompt text file read/written by the prompt tab.

    Returns:
        The assembled ``gr.Blocks`` application.
    """
    # NOTE(review): the Korean UI strings below look like mojibake (UTF-8 bytes
    # decoded with a Thai code page such as cp874, with some bytes lost). They
    # are kept byte-for-byte here; restore them from the original source.
    chat_css = """
    .chat-box { max-height: 500px; overflow-y: auto; padding: 10px; border: 1px solid #ccc; border-radius: 10px; }
    .bubble-left { background-color: #f1f0f0; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: left; clear: both; }
    .bubble-right { background-color: #d1e7ff; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: right; clear: both; text-align: right; }
    .reset-btn-container { text-align: right; margin-bottom: 10px; }
    """

    with gr.Blocks(css=chat_css) as demo:
        with gr.Tabs():
            ### 1. Chat tab ###
            with gr.TabItem("๐ฌ ํ์ง๋ก์ ๋ํํ๊ธฐ"):
                with gr.Column():
                    with gr.Row():
                        gr.Markdown("### ํ์ง๋ก์ ๋ํํ๊ธฐ")
                        # Gradio's `scale` must be an integer; 0 keeps the button at
                        # its natural width instead of stretching across the row.
                        reset_btn = gr.Button("๐ ๋ํ ์ด๊ธฐํ", elem_classes="reset-btn-container", scale=0)
                    # elem_classes (not only elem_id) so the ".chat-box" CSS rule applies.
                    chat_output = gr.HTML(elem_id="chat-box", elem_classes="chat-box")
                    user_input = gr.Textbox(label="๋ฉ์์ง ์ ๋ ฅ", placeholder="ํ์ง๋ก์๊ฒ ๋ง์ ๊ฑธ์ด๋ณด์ธ์")
                    state = gr.State(ctx)

                def render_chat(session_ctx: ContextManager):
                    """Return a gr.update that shows the whole history of session_ctx."""
                    return gr.update(value=_render_history_html(session_ctx.getHistory()))

                def on_submit(user_msg: str, session_ctx: ContextManager):
                    """Stream (chat update, cleared textbox, state) tuples for one user turn.

                    The user's message is shown first, then the view is re-rendered
                    after generate_reply runs. Blank messages are ignored: nothing is
                    added to the history and no reply is generated.
                    """
                    if not user_msg or not user_msg.strip():
                        yield render_chat(session_ctx), "", session_ctx
                        return
                    session_ctx.addHistory("user", user_msg)
                    # Show the user's message before waiting on the reply.
                    yield render_chat(session_ctx), "", session_ctx
                    generate_reply(session_ctx, makePipeline)
                    # Re-render so the reply (presumably appended to the history by
                    # generate_reply) becomes visible.
                    yield render_chat(session_ctx), "", session_ctx

                def reset_chat(session_ctx: ContextManager):
                    """Clear the history and empty both the chat view and the input box."""
                    session_ctx.clearHistory()
                    return gr.update(value=""), "", session_ctx

                user_input.submit(on_submit, inputs=[user_input, state], outputs=[chat_output, user_input, state], queue=True)
                reset_btn.click(reset_chat, inputs=[state], outputs=[chat_output, user_input, state])

            ### 2. Model settings tab ###
            with gr.TabItem("โ๏ธ ๋ชจ๋ธ ์ค์ "):
                gr.Markdown("### LLM ํ๋ผ๋ฏธํฐ ์ค์ ")
                with gr.Row():
                    temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p")
                    repetition_penalty = gr.Slider(0.8, 2.0, value=1.05, step=0.01, label="Repetition Penalty")
                with gr.Row():
                    max_tokens = gr.Slider(16, 2048, value=96, step=8, label="Max New Tokens")
                    apply_btn = gr.Button("โ ์ค์ ์ ์ฉ")

                def collect_llm_settings(temp, topp, repeat, max_tok) -> dict:
                    """Build the LLM settings dict shared by 'apply' and 'save'."""
                    return {
                        "temperature": temp,
                        "top_p": topp,
                        "repetition_penalty": repeat,
                        # Slider values may arrive as floats; a token count must be an int.
                        "max_new_tokens": int(max_tok),
                    }

                def update_config(temp, topp, repeat, max_tok):
                    """Apply the current slider values to makePipeline."""
                    makePipeline.update_config(collect_llm_settings(temp, topp, repeat, max_tok))
                    return gr.update(value="โ ์ค์ ์ ์ฉ ์๋ฃ")

                # Load / export settings buttons.
                with gr.Row():
                    load_btn = gr.Button("๐ ์ค์ ๋ถ๋ฌ์ค๊ธฐ")
                    save_btn = gr.Button("๐พ ์ค์ ๋ด๋ณด๋ด๊ธฐ")

                def load_config():
                    """Read LLM settings from config_path into the sliders; missing keys use the slider defaults."""
                    llm_cfg = load_llm_config(config_path) or {}
                    return (
                        llm_cfg.get("temperature", 0.7),
                        llm_cfg.get("top_p", 0.9),
                        llm_cfg.get("repetition_penalty", 1.05),
                        llm_cfg.get("max_new_tokens", 96),
                        "๐ ์ค์ ๋ถ๋ฌ์ค๊ธฐ ์๋ฃ",
                    )

                def save_config(temp, topp, repeat, max_tok):
                    """Write the slider values into the "llm" block of config_path.

                    The rest of the config file, and any "llm" keys these sliders do
                    not control, are preserved.
                    """
                    config = load_full_config(config_path)
                    config["llm"] = {
                        **(config.get("llm") or {}),
                        **collect_llm_settings(temp, topp, repeat, max_tok),
                    }
                    save_full_config(config, path=config_path)
                    return gr.update(value="๐พ ์ค์ ์ ์ฅ ์๋ฃ")

                # Status box placed at the bottom of the tab.
                status = gr.Textbox(label="", interactive=False)

                # Wire the buttons; every handler takes (temp, top_p, repeat, max_tokens).
                apply_btn.click(
                    update_config,
                    inputs=[temperature, top_p, repetition_penalty, max_tokens],
                    outputs=[status],
                )
                load_btn.click(
                    load_config,
                    inputs=None,
                    outputs=[temperature, top_p, repetition_penalty, max_tokens, status],
                )
                save_btn.click(
                    save_config,
                    inputs=[temperature, top_p, repetition_penalty, max_tokens],
                    outputs=[status],
                )

            ### 3. Prompt editing tab ###
            with gr.TabItem("๐ ํ๋กฌํํธ ์ค์ "):
                gr.Markdown("### ์บ๋ฆญํฐ ๋ฐ ๋ฐฐ๊ฒฝ ๋กฌํํธ ํธ์ง")
                prompt_editor = gr.Textbox(
                    lines=20,
                    label="ํ ์คํธ (init.txt)",
                    placeholder="!! ๋ฐ๋์ ๋ถ๋ฌ์ค๊ธฐ๋ฅผ ๋จผ์ ํ์ธ์ !!",
                    interactive=True,
                )
                with gr.Row():
                    gr.Markdown("#### !! ๋ฐ๋์ ๋ถ๋ฌ์ค๊ธฐ๋ฅผ ๋จผ์ ํ์ธ์ !!")
                with gr.Row():
                    load_prompt_btn = gr.Button("๐ ํ์ฌ ํ๋กฌํํธ ๋ถ๋ฌ์ค๊ธฐ")
                    save_prompt_btn = gr.Button("๐พ ์์ฑํ ํ๋กฌํํธ๋ก ๊ต์ฒด")
                # Save results go to their own status box: returning them to the
                # save button would replace the button's label.
                prompt_status = gr.Textbox(label="", interactive=False)

                def load_prompt():
                    """Return the prompt file's text, or "" if the file does not exist."""
                    try:
                        return Path(prompt_path).read_text(encoding="utf-8")
                    except FileNotFoundError:
                        return ""

                def save_prompt(text):
                    """Overwrite the prompt file with the editor text.

                    Blank text is refused and the file is left untouched, so that
                    pressing save before loading cannot wipe the prompt.
                    """
                    if not text or not text.strip():
                        return "!! ๋ฐ๋์ ๋ถ๋ฌ์ค๊ธฐ๋ฅผ ๋จผ์ ํ์ธ์ !!"
                    target = Path(prompt_path)
                    target.parent.mkdir(parents=True, exist_ok=True)
                    target.write_text(text, encoding="utf-8")
                    return "๐พ ์ ์ฅ ์๋ฃ!"

                load_prompt_btn.click(load_prompt, inputs=None, outputs=prompt_editor)
                save_prompt_btn.click(save_prompt, inputs=[prompt_editor], outputs=[prompt_status])

    return demo