Fiction_Studio / app.py
callanwu's picture
fix bug
d4e0643
raw
history blame
13.8 kB
import sys
sys.path.append("../../Gradio_Config")
import os
from gradio_base import WebUI, UIHelper, PORT, HOST
from gradio_config import GradioConfig as gc
from gradio_config import StateConfig as sc
from typing import List
import gradio as gr
class NovelUI(WebUI):
    """Gradio front end for the Fiction Studio demo.

    The class builds a page with a step progress bar, a main chat window, a
    secondary "Record" chat window and a list of generated files, and streams
    messages received from the backend process into those widgets.

    Socket handling, ``self.cache``, ``render_bubble``, ``send_start_cmd``,
    ``receive_server``, ``reset`` and ``run`` come from ``WebUI`` and are not
    visible in this file.
    """

    # Backend node/state identifier -> HTML label drawn in the progress bar.
    node2show = {
        "Node 1": "Write Character Settings and Script Outlines🖊️",
        "Node 2": "Expand the first chapter<br>✍️",
        "Node 3": "Expand the second chapter<br>✍️",
        "Node 4": "Expand the third chapter<br>✍️",
        "state1": "Perform the first plot<br>🎭",
        "state2": "Perform the second plot<br>🎭",
        "state3": "Perform the third plot<br>🎭",
        "state4": "Perform the forth plot<br>🎭"
    }
    # Reverse lookup (label -> identifier). Built at class-definition time so
    # it is usable before any instance exists; refreshed again in __init__.
    show2node = {label: node for node, label in node2show.items()}

    def __init__(
        self,
        client_cmd: list,
        socket_host: str = HOST,
        socket_port: int = PORT,
        bufsize: int = 1024,
        ui_name: str = "NovelUI"
    ):
        """Start the backend via ``WebUI`` and prepare progress-bar state.

        Args:
            client_cmd: Command line forwarded to ``WebUI.__init__``.
            socket_host: Host forwarded to ``WebUI.__init__``.
            socket_port: Port forwarded to ``WebUI.__init__``.
            bufsize: Buffer size forwarded to ``WebUI.__init__``.
            ui_name: UI name forwarded to ``WebUI.__init__``.

        Raises:
            AssertionError: If the first message from the client did not
                populate every required ``self.cache`` key.
        """
        super(NovelUI, self).__init__(client_cmd, socket_host, socket_port, bufsize, ui_name)
        self.first_recieve_from_client()

        # Explicit raise instead of a bare ``assert`` so the check survives
        # ``python -O``; the exception type is unchanged.
        required_keys = ('agents_name', 'nodes_name', 'output_file_path', 'requirement')
        missing = [key for key in required_keys if key not in self.cache]
        if missing:
            raise AssertionError(f"Missing required cache entries: {missing}")

        # One slot per node; filled in by update_progress().
        node_count = len(self.cache['nodes_name'])
        self.progress_manage = {
            "schedule": [None] * node_count,
            "show_type": [None] * node_count,
            "show_content": [None] * node_count
        }

        # Per-run message buffers read by handle_message(); they are reset at
        # the start of every run in btn_start_after_click().
        self.data_recorder = []
        self.data_history = []

        # Refresh the reverse map in case node2show was edited after the
        # class body was executed.
        NovelUI.show2node = {label: node for node, label in NovelUI.node2show.items()}

    def render_and_register_ui(self):
        """Normalise ``cache['agents_name']`` to a list and register it with ``gc``."""
        names = self.cache['agents_name']
        self.agent_name: list = [names] if isinstance(names, str) else names
        gc.add_agent(self.agent_name)

    def handle_message(self, history: list, record: list, state, agent_name, token, node_name):
        """Apply one streamed token to the chat window it belongs to.

        ``state // 10 == 2`` selects the "Record" window (``record`` /
        ``self.data_recorder``); any other tens digit selects the main window
        (``history`` / ``self.data_history``). ``state % 10`` selects the
        action: 0 appends a new ``{agent_name: token}`` entry, 1 appends
        ``token`` to the last entry, 2 starts a new chat row and resets the
        buffer to a single new entry.

        Args:
            history: Rendered rows of the main chat window.
            record: Rendered rows of the "Record" chat window.
            state: Integer state code sent by the backend.
            agent_name: Speaker the token belongs to.
            token: Text fragment to add.
            node_name: Node identifier passed through to ``render_bubble``.

        Returns:
            The ``(history, record)`` pair with the targeted one replaced by
            the result of ``self.render_bubble``.

        Raises:
            AssertionError: If ``state % 10`` is not 0, 1 or 2.
        """
        to_recorder = state // 10 == 2
        render_data: list = record if to_recorder else history
        data: list = self.data_recorder if to_recorder else self.data_history

        action = state % 10
        if action == 0:
            data.append({agent_name: token})
        elif action == 1:
            # Continuation of the current entry: extend its text.
            data[-1][agent_name] += token
        elif action == 2:
            # New bubble: open a fresh row and restart the buffer.
            render_data.append([None, ""])
            data.clear()
            data.append({agent_name: token})
        else:
            # Explicit raise (not ``assert False``) so it is not stripped by -O.
            raise AssertionError("Invalid state.")

        render_data = self.render_bubble(render_data, data, node_name, render_node_name=True)
        if to_recorder:
            record = render_data
        else:
            history = render_data
        return history, record

    def update_progress(self, node_name, node_schedule):
        """Recompute the progress bar for the node currently running.

        Nodes listed before ``node_name`` in ``cache['nodes_name']`` are set
        to 100, the current node to ``node_schedule`` and later nodes to 0.
        When the current node is the second-to-last one and has reached 100,
        the last node is also set to 100.

        Args:
            node_name: Backend identifier of the active node; must be a key
                of ``node2show`` and an element of ``cache['nodes_name']``.
            node_schedule: Progress value for the active node.

        Returns:
            The HTML produced by ``sc.FORMAT`` with ``sc.CSS`` and the output
            of ``sc.update_states``.
        """
        nodes = self.cache['nodes_name']
        schedule = self.progress_manage['schedule']
        show_type = self.progress_manage['show_type']
        show_content = self.progress_manage['show_content']

        current_label = self.node2show[node_name]
        reached_current = False
        for idx, name in enumerate(nodes):
            is_current = self.node2show[name] == current_label
            show_type[idx] = "active-show-up"
            # One-element tuple kept exactly as before; presumably the shape
            # sc.update_states expects -- TODO confirm against gradio_config.
            show_content[idx] = ("💬" if is_current else "",)
            if is_current:
                reached_current = True
                schedule[idx] = node_schedule
            elif not reached_current:
                schedule[idx] = 100
            else:
                schedule[idx] = 0

        # node2show is one-to-one, so looking the raw name up directly is
        # equivalent to the round trip through show2node.
        if nodes.index(node_name) == len(nodes) - 2 and node_schedule == 100:
            schedule[-1] = 100

        return sc.FORMAT.format(
            sc.CSS,
            sc.update_states(
                current_states=schedule,
                current_templates=show_type,
                show_content=show_content
            )
        )

    def _initial_progress_html(self):
        """Return the progress-bar HTML for a run that has not started yet."""
        # A fresh list is passed so create_states cannot alter self.cache.
        labels = [NovelUI.node2show[name] for name in self.cache['nodes_name']]
        return sc.FORMAT.format(sc.CSS, sc.create_states(labels, False))

    def _list_output_files(self):
        """Return the regular files directly inside ``cache['output_file_path']``.

        Returns an empty list when the directory does not exist. The result
        is sorted so the file list keeps a stable order between refreshes.
        """
        root = self.cache['output_file_path']
        if not os.path.isdir(root):
            return []
        paths = (os.path.join(root, entry) for entry in sorted(os.listdir(root)))
        return [path for path in paths if os.path.isfile(path)]

    def construct_ui(self):
        """Build the Gradio layout, wire the event listeners, store it in ``self.demo``."""
        with gr.Blocks(css=gc.CSS) as demo:
            with gr.Column():
                self.progress = gr.HTML(value=self._initial_progress_html())
                with gr.Row():
                    with gr.Column(scale=6):
                        self.chatbot = gr.Chatbot(
                            elem_id="chatbot1",
                            label="Dialog",
                            height=500
                        )
                        self.text_api = gr.Textbox(
                            # "api_key" is not among the keys checked in
                            # __init__, so fall back to an empty box.
                            value=self.cache.get("api_key", ""),
                            placeholder="openai key",
                            label="Please input valid openai key for gpt-3.5-turbo-16k."
                        )
                        with gr.Row():
                            self.text_requirement = gr.Textbox(
                                placeholder="Requirement of the novel",
                                value=self.cache['requirement'],
                                scale=9
                            )
                            self.btn_start = gr.Button(
                                value="Start",
                                scale=1
                            )
                        self.btn_reset = gr.Button(
                            value="Restart",
                            visible=False
                        )
                    with gr.Column(scale=5):
                        self.chat_record = gr.Chatbot(
                            elem_id="chatbot1",
                            label="Record",
                            visible=False
                        )
                        self.file_show = gr.File(
                            value=[],
                            label="FileList",
                            visible=False
                        )
                        self.chat_show = gr.Chatbot(
                            elem_id="chatbot1",
                            label="FileRead",
                            visible=False
                        )

                # Every handler below except file_when_select returns values
                # for this same set of widgets, in this order.
                all_outputs = [
                    self.progress, self.chatbot, self.chat_record, self.chat_show,
                    self.btn_start, self.btn_reset, self.text_requirement, self.file_show
                ]
                # ===============Event Listener===============
                self.btn_start.click(
                    fn=self.btn_start_when_click,
                    inputs=[self.text_requirement, self.text_api],
                    outputs=[self.chatbot, self.chat_record, self.btn_start, self.text_requirement]
                ).then(
                    fn=self.btn_start_after_click,
                    inputs=[self.chatbot, self.chat_record],
                    outputs=all_outputs
                )
                self.btn_reset.click(
                    fn=self.btn_reset_when_click,
                    inputs=[],
                    outputs=all_outputs
                ).then(
                    fn=self.btn_reset_after_click,
                    inputs=[],
                    outputs=all_outputs
                )
                self.file_show.select(
                    fn=self.file_when_select,
                    inputs=[self.file_show],
                    outputs=[self.chat_show]
                )
                # ===========================================
            self.demo = demo

    def btn_start_when_click(self, text_requirement: str, api_key: str):
        """Show the user's requirement, lock the inputs, then start the backend.

        inputs=[self.text_requirement, self.text_api],
        outputs=[self.chatbot, self.chat_record, self.btn_start, self.text_requirement]
        """
        history = [[UIHelper.wrap_css(content=text_requirement, name="User"), None]]
        # Yield first so the page updates before the start command is sent.
        yield history,\
            gr.Chatbot.update(visible=True),\
            gr.Button.update(interactive=False, value="Running"),\
            gr.Textbox.update(value="", interactive=False)
        self.send_start_cmd({'requirement': text_requirement, "api_key": api_key})

    def btn_start_after_click(self, history: List, record):
        """Stream backend messages into the UI until the finish state arrives.

        inputs=[self.chatbot, self.chat_record],
        outputs=[self.progress, self.chatbot, self.chat_record, self.chat_show,
                 self.btn_start, self.btn_reset, self.text_requirement, self.file_show]

        Each received item is a list
        ``[state, agent_name, token, node_name, node_schedule]`` where state is
        10/11/12 for the main chat window, 20/21/22 for the record window,
        30 to register a new agent and 99 when the run has finished.
        """
        self.data_recorder = list()
        self.data_history = list()
        receive_server = self.receive_server
        while True:
            data_list: List = receive_server.send(None)
            for item in data_list:
                # SECURITY NOTE: eval() executes whatever text arrives over the
                # socket. Kept because the wire format is defined by the
                # backend, which is not visible here; if the payload is always
                # a plain literal, ast.literal_eval would be a safer parser.
                data = eval(item)
                assert isinstance(data, list)
                state, agent_name, token, node_name, node_schedule = data
                assert isinstance(state, int)
                # 10/11/12 -> history
                # 20/21/22 -> recorder
                # 99 -> finish
                # 30 -> register new agent
                assert state in [10, 11, 12, 20, 21, 22, 99, 30]
                if state == 30:
                    # register new agent (same eval caveat as above).
                    gc.add_agent(eval(agent_name))
                    continue

                # Listed only for states that refresh the page.
                fs: List = self._list_output_files()
                if state == 99:
                    # finish: reveal the file reader and the restart button.
                    yield gr.HTML.update(value=self.update_progress(node_name, node_schedule)),\
                        history,\
                        gr.Chatbot.update(visible=True, value=record),\
                        gr.Chatbot.update(visible=True),\
                        gr.Button.update(visible=True, interactive=False, value="Done"),\
                        gr.Button.update(visible=True, interactive=True),\
                        gr.Textbox.update(visible=True, interactive=False),\
                        gr.File.update(value=fs, visible=True, interactive=True)
                    return

                history, record = self.handle_message(history, record, state, agent_name, token, node_name)
                yield gr.HTML.update(value=self.update_progress(node_name, node_schedule)),\
                    history,\
                    gr.Chatbot.update(visible=True, value=record),\
                    gr.Chatbot.update(visible=False),\
                    gr.Button.update(visible=True, interactive=False),\
                    gr.Button.update(visible=False, interactive=True),\
                    gr.Textbox.update(visible=True, interactive=False),\
                    gr.File.update(value=fs, visible=True, interactive=True)

    def btn_reset_when_click(self):
        """Clear the page and show a "Restarting..." state while the reset runs.

        inputs=[],
        outputs=[self.progress, self.chatbot, self.chat_record, self.chat_show,
                 self.btn_start, self.btn_reset, self.text_requirement, self.file_show]
        """
        # Same wrapped markup and display labels as the initial page, so the
        # bar does not fall back to raw node ids without styling.
        return gr.HTML.update(value=self._initial_progress_html()),\
            gr.Chatbot.update(value=None),\
            gr.Chatbot.update(value=None, visible=False),\
            gr.Chatbot.update(value=None, visible=False),\
            gr.Button.update(value="Restarting...", visible=True, interactive=False),\
            gr.Button.update(value="Restarting...", visible=True, interactive=False),\
            gr.Textbox.update(value="Restarting...", interactive=False, visible=True),\
            gr.File.update(visible=False)

    def btn_reset_after_click(self):
        """Reset the backend connection and return the page to its start state.

        inputs=[],
        outputs=[self.progress, self.chatbot, self.chat_record, self.chat_show,
                 self.btn_start, self.btn_reset, self.text_requirement, self.file_show]
        """
        self.reset()
        self.first_recieve_from_client(reset_mode=True)
        return gr.HTML.update(value=self._initial_progress_html()),\
            gr.Chatbot.update(value=None),\
            gr.Chatbot.update(value=None, visible=False),\
            gr.Chatbot.update(value=None, visible=False),\
            gr.Button.update(value="Start", visible=True, interactive=True),\
            gr.Button.update(value="Restart", visible=False, interactive=False),\
            gr.Textbox.update(value="", interactive=True, visible=True),\
            gr.File.update(visible=False)

    def file_when_select(self, file_obj):
        """Return the selected file's text as one fenced chat message.

        inputs=[self.file_show],
        outputs=[self.chat_show]
        """
        CODE_PREFIX = "```json\n{}\n```"
        with open(file_obj.name, "r", encoding='utf-8') as f:
            codes = f.read()
        return [[CODE_PREFIX.format(codes), None]]
if __name__ == '__main__':
    # Script entry point: create the UI with the backend launch command,
    # build the Gradio page, then hand control to the run loop.
    backend_cmd = ["python", "gradio_backend.py"]
    novel_ui = NovelUI(client_cmd=backend_cmd)
    novel_ui.construct_ui()
    novel_ui.run()