# DynamicSceneGeneration / generate_person.py
# wcx
# Init commit
# 8f3b56b
from io import BytesIO
import gradio as gr
import pandas as pd
from gradio_calendar import Calendar
import datetime
from avatar_generation import get_persona_avatar_bytes
from gen_schedule.generate_character import gen_character
import pickle
from generate_scene import get_person_info_list, person_info_to_description
from utils import CACHE_PERSON_PATH_LIST, MAX_DAYS, SCHEDULE_COLUMNS, create_file_for_character, read_character
from PIL import Image
def get_result(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name):
    """Generate a character (schedule data plus avatar) from the form values.

    Args:
        name, gender, routine, occupation, thoughts, lifestyle: copied
            unchanged into the persona dict.
        age: converted with ``int()`` before being stored in the persona.
        personality_list: table with a ``'Traits'`` column; blank and missing
            cells are ignored.
        start_date, end_date: passed to ``gen_character`` as a
            ``(start_date, end_date)`` pair.
        openai_key: forwarded to ``gen_character`` and
            ``get_persona_avatar_bytes``.
        model_name: forwarded to ``gen_character``.

    Returns:
        The dict returned by ``gen_character``, with the generated avatar
        stored as a PIL image under ``['persona']['image']``.
    """
    # The traits table in the form keeps a fixed number of rows, so unfilled
    # rows arrive as empty strings or NaN. Keep only traits actually typed in,
    # otherwise '' / nan would end up in the persona sent to the model.
    traits = [
        trait
        for trait in personality_list['Traits'].to_list()
        if pd.notna(trait) and str(trait).strip()
    ]
    persona = {
        'name': name,
        'age': int(age),
        'gender': gender,
        'routine': routine,
        'personality': traits,
        'occupation': occupation,
        'thoughts': thoughts,
        'lifestyle': lifestyle,
    }
    data_span = (start_date, end_date)
    character_dict = gen_character(persona, data_span, model_name=model_name, openai_api_key=openai_key)
    avatar = get_persona_avatar_bytes(persona, api_key=openai_key)
    # The avatar arrives as raw bytes; decode it into a PIL image for storage.
    character_dict['persona']['image'] = Image.open(BytesIO(avatar))
    return character_dict
def schedule_to_df(schedules) -> tuple[list[pd.DataFrame], list[str]]:
    """Convert a per-date schedule mapping into DataFrames and date labels.

    Args:
        schedules: mapping from a date-like key (anything providing
            ``strftime``) to data accepted by ``pd.DataFrame`` that contains
            the ``SCHEDULE_COLUMNS`` columns.

    Returns:
        A pair ``(schedule_dfs, date_str_list)``. ``schedule_dfs[i]`` is the
        schedule of one day restricted to ``SCHEDULE_COLUMNS`` and
        ``date_str_list[i]`` is that day formatted as ``YYYY-MM-DD``. Both
        lists follow the iteration order of ``schedules``.
    """
    schedule_dfs = []
    date_str_list = []
    for date, schedule in schedules.items():
        schedule_dfs.append(pd.DataFrame(schedule)[SCHEDULE_COLUMNS])
        date_str_list.append(date.strftime("%Y-%m-%d"))
    return schedule_dfs, date_str_list
def update_schedules(schedules) -> list:
    """Build the Gradio updates that display one schedule per day tab.

    Args:
        schedules: per-date schedule mapping, as accepted by
            ``schedule_to_df``.

    Returns:
        A list of ``2 * MAX_DAYS`` updates: first ``MAX_DAYS`` dataframe
        updates, then ``MAX_DAYS`` tab updates. Slots with a schedule are made
        visible (the tab is labelled with the date); the rest stay hidden.
    """
    schedule_df_update = [gr.update(visible=False) for _ in range(MAX_DAYS)]
    schedule_tab_update = [gr.update(visible=False) for _ in range(MAX_DAYS)]
    schedule_df_list, date_str_list = schedule_to_df(schedules)
    # Only MAX_DAYS slots exist. Cap the days shown so that a schedule with
    # more days fills the available tabs instead of raising IndexError.
    shown_days = zip(schedule_df_list[:MAX_DAYS], date_str_list[:MAX_DAYS])
    for i, (df, date_str) in enumerate(shown_days):
        schedule_df_update[i] = gr.update(visible=True, value=df)
        schedule_tab_update[i] = gr.update(visible=True, label=f"{date_str}")
    return schedule_df_update + schedule_tab_update
def change_submit_state():
    """Return a pair of updates: show the first output, hide the second.

    It is wired to the submit button click with outputs
    ``[result, submit_btn]``, so the result box appears and the submit button
    is hidden.
    """
    reveal = gr.update(visible=True)
    conceal = gr.update(visible=False)
    return reveal, conceal
def _has_filled_trait(personality_list):
    """Return True when the traits table has at least one non-blank 'Traits' cell."""
    if personality_list is None or len(personality_list) == 0:
        return False
    return any(
        pd.notna(trait) and str(trait).strip()
        for trait in personality_list['Traits'].to_list()
    )


def submit_info(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name):
    """Validate the form and, when it is valid, generate and save the character.

    The checks run in order and the first failing one supplies the message.
    When every check passes the character is generated with ``get_result``,
    written out with ``create_file_for_character`` and its schedules are
    prepared for display.

    Returns:
        A list ``[message, file_update, add_person_button_update]`` followed
        by ``2 * MAX_DAYS`` schedule updates (dataframe updates first, then tab
        updates). On validation failure every schedule update hides its
        component and the file and button updates leave the download and
        "add" controls as they were / hidden.
    """
    add_person_button_visible = gr.update(visible=False)
    file_update = gr.update()
    message = None
    # check the state
    if not name:
        message = "Please fill in the name of the person."
    elif not age:
        # Note: an age of 0 is falsy and is therefore treated as missing.
        message = "Please fill in the age of the person."
    elif gender not in ["Male", "Female"]:
        message = "Please select the gender from Male or Female"
    elif not _has_filled_trait(personality_list):
        # The traits table always has its fixed rows, so checking its length
        # alone would never detect an empty form; look at the cell contents.
        message = "Please fill in the personality traits of the person."
    elif start_date is None or end_date is None:
        # Comparing None dates below would raise TypeError.
        message = "Please select the start date and end date."
    # start date must be before end date and the period can not exceed MAX_DAYS days
    elif start_date > end_date:
        message = "Start date must be before end date."
    elif (end_date - start_date).days > MAX_DAYS:
        message = f"The period can not exceed {MAX_DAYS} days."
    elif not routine:
        message = "Please fill in the routine of the person."
    elif not occupation:
        message = "Please fill in the occupation of the person."
    elif not thoughts:
        message = "Please fill in the thoughts of the person."
    elif not lifestyle:
        message = "Please fill in the lifestyle of the person."
    elif not openai_key:
        message = "Please fill in the OpenAI API Key."
    elif model_name not in ["gpt3.5", "gpt4"]:
        message = "Please select the model name from gpt3.5 or gpt4."

    if message is None:
        character_dict = get_result(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name)
        file_path = create_file_for_character(character_dict)
        file_update = gr.update(visible=True, value=str(file_path))
        add_person_button_visible = gr.update(visible=True)
        message = "Character created successfully!"
        _, _, schedules = read_character(character_dict)
        schedule_updates = update_schedules(schedules)
    else:
        # Nothing to show: hide every schedule dataframe and tab.
        schedule_updates = [gr.update(visible=False) for _ in range(MAX_DAYS * 2)]
    return [
        message, file_update, add_person_button_visible
    ] + schedule_updates
def reset():
    """Return the values that put the form back into its initial state.

    The tuple holds, in order: the twelve cleared form values, four visibility
    updates (the first shows its component, the other three hide theirs) and
    one hiding update per schedule tab (``MAX_DAYS`` of them).
    """
    blank_traits = pd.DataFrame(columns=["Traits"], index=[0, 1, 2])
    form_values = (
        "", 0, "", blank_traits,
        "", "", "", "",
        datetime.datetime.now(), datetime.datetime.now(),
        "", "",
    )
    control_updates = (
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
    )
    hidden_tabs = tuple(gr.update(visible=False) for _ in range(MAX_DAYS))
    return form_values + control_updates + hidden_tabs
def init_display_person_dataframe():
    """Create ``MAX_DAYS`` hidden tabs, each holding one empty dataframe.

    Returns:
        ``(dataframes, tabs)``: two lists of length ``MAX_DAYS`` in day order,
        where ``dataframes[i]`` is the component created inside ``tabs[i]``.
    """
    day_frames, day_tabs = [], []
    with gr.Blocks():
        for day_number in range(MAX_DAYS):
            # Tabs start hidden; they are revealed once a schedule exists.
            with gr.Tab(label=f"Day {day_number}", visible=False) as day_tab:
                day_frame = gr.Dataframe()
            day_frames.append(day_frame)
            day_tabs.append(day_tab)
    return day_frames, day_tabs
def visualize_person_info():
    """Build the character-creation form and wire up its event handlers.

    Components are created in the order they should appear, so this is meant
    to be called inside an active Gradio layout context (the script entry
    point calls it within ``gr.Blocks()``).

    Returns:
        A tuple ``(add_person_btn, character_file, person_name_to_file,
        person_name_to_info, person_name_to_description)``: the "Add
        Character" button, the download file component, and three
        ``gr.State`` objects holding name-keyed lookups for the cached
        characters.
    """
    # Lookups for the cached characters, keyed by each character's name.
    person_file_list: list[str|dict] = CACHE_PERSON_PATH_LIST
    person_info_list = get_person_info_list(person_file_list)
    person_name_to_file = {person_info['name']: person_file for person_info, person_file in zip(person_info_list, person_file_list)}
    person_name_to_info = {person_info['name']: person_info for person_info in person_info_list}
    person_name_to_description = {person_info['name']: person_info_to_description(person_info) for person_info in person_info_list}
    # Wrap the lookups in gr.State so they can be returned as components.
    person_name_to_file = gr.State(person_name_to_file)
    person_name_to_info = gr.State(person_name_to_info)
    person_name_to_description = gr.State(person_name_to_description)

    # --- Basic info ---
    gr.Markdown("### Fill in the basic info of the person you want to create")
    with gr.Row():
        name = gr.Textbox(label="Name")
        age = gr.Number(label="Age", precision=0, minimum=0)
        gender = gr.Dropdown(label="Gender", choices=["Male", "Female"])

    # --- Personality traits: a one-column table with three rows ---
    gr.Markdown("### Fill in the personality traits of the person")
    personality_list = gr.Dataframe(col_count=1, row_count=3, label=None, headers=["Traits"])

    # --- Free-text persona fields ---
    gr.Markdown("### Fill in the routine, occupation, thoughts, and lifestyle of the person")
    routine = gr.Textbox(label="Routine", placeholder="e.g. Aarav typically sleeps by 10pm and wakes up at 6am for school.")
    occupation = gr.Textbox(label="Occupation", placeholder="e.g. Aarav is a high school student who is preparing for university entrance exams.")
    thoughts = gr.TextArea(label="Thoughts", placeholder="e.g. Lately, he's been worried about his chemistry exam on February 2nd, 2024, and has been studying extra hours after school to prepare.")
    lifestyle = gr.TextArea(label="Lifestyle", placeholder="e.g. Aarav follows a vegetarian Indian diet and practices yoga every evening. He prefers to study in a quiet room and enjoys taking long walks as a study break. His room is equipped with minimal technology, as he tries to avoid distractions.")

    # --- Date range ---
    gr.Markdown("### Select the dates for the person's lifestyle:")
    with gr.Row():
        start_date = Calendar(type="datetime", label="Start Date")
        end_date = Calendar(type="datetime", label="End Date")
    gr.Textbox(f"The period can not exceed {MAX_DAYS} days.", interactive=False, label="Note", show_copy_button=False)
    # Whenever the start date changes, set the end date to MAX_DAYS days later.
    start_date.change(
        fn=lambda x: x + datetime.timedelta(days=MAX_DAYS),
        inputs=[start_date], outputs=[end_date]
    )

    # --- OpenAI settings ---
    gr.Markdown("### Add openai settings to generate the character")
    openai_key = gr.Textbox(label="OpenAI API Key", placeholder="Enter your OpenAI API Key")
    model_name = gr.Dropdown(label="Model Name", choices=["gpt3.5", "gpt4"])

    # --- Actions and result widgets (result, download and add start hidden) ---
    submit_btn = gr.Button("Submit")
    result = gr.Textbox(label="Result", lines=2, visible=False, interactive=False)
    schedule_df_list, schedule_tab_list = init_display_person_dataframe()
    character_file = gr.File(label="Download Character", visible=False)
    add_person_btn = gr.Button("Add Character", visible=False)
    clear_btn = gr.Button("Clear")

    # Two handlers are attached to the same click. This one only toggles
    # visibility: it shows the result box and hides the submit button.
    submit_btn.click(
        change_submit_state,
        outputs=[result, submit_btn]
    )
    # This one validates the form, generates the character and fills in the
    # message, download file, add button and the per-day schedule widgets.
    submit_btn.click(
        submit_info,
        inputs=[name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name],
        outputs=[result, character_file, add_person_btn] + schedule_df_list + schedule_tab_list,
    )
    # Clear resets every input, shows the submit button again and hides the
    # result widgets; only the schedule tabs (not their dataframes) are outputs.
    clear_btn.click(
        fn=reset,
        outputs=[
            name, age, gender, personality_list, routine,
            occupation, thoughts, lifestyle, start_date, end_date,
            openai_key, model_name, submit_btn, result, character_file, add_person_btn
        ] + schedule_tab_list
    )
    return add_person_btn, character_file, person_name_to_file, person_name_to_info, person_name_to_description
if __name__ == "__main__":
    # Stand-alone entry point: host the character-creation form in its own
    # Blocks app and start the Gradio server.
    with gr.Blocks() as app:
        visualize_person_info()
    app.launch()
# data_path = 'static/exist_characters/Aarav_Patil/Aarav_Patil.pkl'
# import pickle
# data = pickle.load(open(data_path, 'rb'))
# persona = data['persona']
# dates = list(data['schedule']['default'].keys())
# start_date = dates[0]
# end_date = dates[-1]
# character_dict = gen_character(persona, (start_date, start_date), model_name="gpt3", openai_api_key="")
# # from IPython import embed; embed()