# GodSaveMoney / core /launch_gradio.py
# Jeong-hun Kim
# minor fix
# 576221b
import gradio as gr
from core.context_manager import ContextManager
from core.make_pipeline import MakePipeline
from core.make_reply import generate_reply
from core.utils import load_config as load_full_config, save_config as save_full_config, load_llm_config
import re
def create_interface(ctx: ContextManager, makePipeline: MakePipeline):
    """Build the three-tab Gradio UI (chat, model settings, prompt editor).

    Args:
        ctx: Conversation context. It is used as the initial value of the
            chat tab's ``gr.State``; handlers call ``addHistory``,
            ``getHistory`` and ``clearHistory`` on the state value.
        makePipeline: Pipeline object handed to ``generate_reply`` and
            whose ``update_config`` receives the slider values.

    Returns:
        The constructed ``gr.Blocks`` instance (this function does not
        launch it).
    """
    # Function-scope import keeps this block self-contained; the name is
    # aliased so it cannot be confused with local HTML string variables.
    from html import escape as escape_html

    # Matches either an *action* span (group 1) or a run of plain text
    # without asterisks (group 2). Compiled once instead of per render.
    action_pattern = re.compile(r"\*(.+?)\*|([^\*]+)")

    with gr.Blocks(css="""
.chat-box { max-height: 500px; overflow-y: auto; padding: 10px; border: 1px solid #ccc; border-radius: 10px; }
.bubble-left { background-color: #f1f0f0; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: left; clear: both; }
.bubble-right { background-color: #d1e7ff; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: right; clear: both; text-align: right; }
.reset-btn-container { text-align: right; margin-bottom: 10px; }
""") as demo:
        with gr.Tabs():
            ### 1. Chat tab ###
            with gr.TabItem("๐Ÿ’ฌ ํƒ„์ง€๋กœ์™€ ๋Œ€ํ™”ํ•˜๊ธฐ"):
                with gr.Column():
                    with gr.Row():
                        gr.Markdown("### ํƒ„์ง€๋กœ์™€ ๋Œ€ํ™”ํ•˜๊ธฐ")
                        # Gradio expects an integer scale; 0 means
                        # "do not grow", which keeps the button narrow.
                        reset_btn = gr.Button("๐Ÿ” ๋Œ€ํ™” ์ดˆ๊ธฐํ™”", elem_classes="reset-btn-container", scale=0)
                    # The stylesheet targets the `.chat-box` *class*, so the
                    # class must be set; the id is kept for compatibility.
                    chat_output = gr.HTML(elem_id="chat-box", elem_classes="chat-box")
                    user_input = gr.Textbox(label="๋ฉ”์‹œ์ง€ ์ž…๋ ฅ", placeholder="ํƒ„์ง€๋กœ์—๊ฒŒ ๋ง์„ ๊ฑธ์–ด๋ณด์„ธ์š”")
                state = gr.State(ctx)

                def render_chat(ctx: ContextManager):
                    """Render the conversation history as chat-bubble HTML.

                    Returns a ``gr.update`` whose value is the HTML string.
                    Items whose role is neither "user" nor "bot" are skipped.
                    """

                    def parse_emotion_text(text: str) -> str:
                        """Convert one message into HTML ``<div>`` lines.

                        ``*...*`` spans become gray text; plain text is
                        split into one ``<div>`` per non-empty line. All
                        message text is HTML-escaped so markup typed by the
                        user or produced by the model is shown literally.
                        """
                        segments = []
                        for action, plain in action_pattern.findall(text):
                            if action:
                                segments.append(
                                    f"<div style='color:gray'>*{escape_html(action)}*</div>"
                                )
                            elif plain:
                                for line in plain.strip().splitlines():
                                    line = line.strip()
                                    if line:
                                        segments.append(f"<div>{escape_html(line)}</div>")
                        return "\n".join(segments)

                    bubble_class = {"user": "bubble-right", "bot": "bubble-left"}
                    bubbles = []
                    for item in ctx.getHistory():
                        parsed = parse_emotion_text(item['text'])
                        css_class = bubble_class.get(item["role"])
                        if css_class is not None:
                            bubbles.append(f"<div class='{css_class}'>{parsed}</div>")
                    return gr.update(value="".join(bubbles))

                def on_submit(user_msg: str, ctx: ContextManager):
                    """Handle a submitted message as a two-step generator.

                    First yields the chat with the user's message so it
                    appears immediately, then yields again after
                    ``generate_reply`` has run. Each yield is
                    ``(chat_html_update, cleared_input, ctx)``.
                    """
                    # Ignore blank submissions: do not add an empty bubble
                    # or trigger a reply for them.
                    if not user_msg or not user_msg.strip():
                        yield render_chat(ctx), "", ctx
                        return

                    # Record the user's message in the history.
                    ctx.addHistory("user", user_msg)

                    # Render first so the user's message shows up before the
                    # reply is generated.
                    yield render_chat(ctx), "", ctx

                    # Generate the bot reply.
                    generate_reply(ctx, makePipeline)

                    # Render again from the full history.
                    yield render_chat(ctx), "", ctx

                def reset_chat(ctx: ContextManager):
                    """Clear the history and blank the chat view and input."""
                    ctx.clearHistory()
                    return gr.update(value=""), "", ctx

                user_input.submit(on_submit, inputs=[user_input, state], outputs=[chat_output, user_input, state], queue=True)
                reset_btn.click(reset_chat, inputs=[state], outputs=[chat_output, user_input, state])

            ### 2. Settings tab ###
            with gr.TabItem("โš™๏ธ ๋ชจ๋ธ ์„ค์ •"):
                gr.Markdown("### LLM ํŒŒ๋ผ๋ฏธํ„ฐ ์„ค์ •")

                with gr.Row():
                    temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p")
                    repetition_penalty = gr.Slider(0.8, 2.0, value=1.05, step=0.01, label="Repetition Penalty")

                with gr.Row():
                    max_tokens = gr.Slider(16, 2048, value=96, step=8, label="Max New Tokens")

                apply_btn = gr.Button("โœ… ์„ค์ • ์ ์šฉ")

                def update_config(temp, topp, max_tok, repeat):
                    """Push the slider values into the pipeline's config.

                    Parameter order matches the ``inputs`` list wired to
                    ``apply_btn`` below.
                    """
                    makePipeline.update_config({
                        "temperature": temp,
                        "top_p": topp,
                        "max_new_tokens": max_tok,
                        "repetition_penalty": repeat
                    })
                    return gr.update(value="โœ… ์„ค์ • ์ ์šฉ ์™„๋ฃŒ")

                # Load / export settings buttons.
                with gr.Row():
                    load_btn = gr.Button("๐Ÿ“‚ ์„ค์ • ๋ถˆ๋Ÿฌ์˜ค๊ธฐ")
                    save_btn = gr.Button("๐Ÿ’พ ์„ค์ • ๋‚ด๋ณด๋‚ด๊ธฐ")

                def load_config():
                    """Read LLM settings from config.json for the sliders.

                    Returns slider values in the order of the ``outputs``
                    list wired to ``load_btn``, followed by a status
                    message. Missing keys fall back to the slider defaults.
                    """
                    llm_cfg = load_llm_config("config.json")
                    return (
                        llm_cfg.get("temperature", 0.7),
                        llm_cfg.get("top_p", 0.9),
                        llm_cfg.get("repetition_penalty", 1.05),
                        llm_cfg.get("max_new_tokens", 96),
                        "๐Ÿ“‚ ์„ค์ • ๋ถˆ๋Ÿฌ์˜ค๊ธฐ ์™„๋ฃŒ"
                    )

                def save_config(temp, topp, repeat, max_tok):
                    """Write the slider values into config.json's "llm" block."""
                    # Load the existing full config so other sections survive.
                    config = load_full_config("config.json")

                    # Replace only the LLM block.
                    config["llm"] = {
                        "temperature": temp,
                        "top_p": topp,
                        "repetition_penalty": repeat,
                        "max_new_tokens": max_tok
                    }

                    # Save the whole config back.
                    save_full_config(config, path="config.json")
                    return gr.update(value="๐Ÿ’พ ์„ค์ • ์ €์žฅ ์™„๋ฃŒ")

                # Status box placed at the bottom of the tab.
                status = gr.Textbox(label="", interactive=False)

                # Wire the buttons to their handlers.
                apply_btn.click(
                    update_config,
                    inputs=[temperature, top_p, max_tokens, repetition_penalty],
                    outputs=[status]  # or []
                )

                load_btn.click(
                    load_config,
                    inputs=None,
                    outputs=[temperature, top_p, repetition_penalty, max_tokens, status]
                )

                save_btn.click(
                    save_config,
                    inputs=[temperature, top_p, repetition_penalty, max_tokens],
                    outputs=[status]
                )

            ### 3. Prompt editing tab ###
            with gr.TabItem("๐Ÿ“ ํ”„๋กฌํ”„ํŠธ ์„ค์ •"):
                gr.Markdown("### ์บ๋ฆญํ„ฐ ๋ฐ ๋ฐฐ๊ฒฝ ๋กฌํ”„ํŠธ ํŽธ์ง‘")

                prompt_editor = gr.Textbox(
                    lines=20,
                    label="ํ…์ŠคํŠธ (init.txt)",
                    placeholder="!! ๋ฐ˜๋“œ์‹œ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ๋ฅผ ๋จผ์ € ํ•˜์„ธ์š” !!",
                    interactive=True
                )

                with gr.Row():
                    gr.Markdown("#### !! ๋ฐ˜๋“œ์‹œ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ๋ฅผ ๋จผ์ € ํ•˜์„ธ์š” !!")

                with gr.Row():
                    load_prompt_btn = gr.Button("๐Ÿ“‚ ํ˜„์žฌ ํ”„๋กฌํ”„ํŠธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ")
                    save_prompt_btn = gr.Button("๐Ÿ’พ ์ž‘์„ฑํ•œ ํ”„๋กฌํ”„ํŠธ๋กœ ๊ต์ฒด")

                def load_prompt():
                    """Return the contents of init.txt, or "" if it is missing."""
                    try:
                        with open("assets/prompt/init.txt", "r", encoding="utf-8") as f:
                            return f.read()
                    except FileNotFoundError:
                        return ""

                def save_prompt(text):
                    """Overwrite init.txt with ``text`` and return a status label."""
                    with open("assets/prompt/init.txt", "w", encoding="utf-8") as f:
                        f.write(text)
                    return "๐Ÿ’พ ์ €์žฅ ์™„๋ฃŒ!"

                load_prompt_btn.click(
                    load_prompt,
                    inputs=None,
                    outputs=prompt_editor
                )

                # The returned label is written to the save button itself.
                save_prompt_btn.click(
                    save_prompt,
                    inputs=[prompt_editor],
                    outputs=[save_prompt_btn]
                )

    return demo