Spaces:
Running
Running
File size: 9,941 Bytes
8f3b56b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
from io import BytesIO
import gradio as gr
import pandas as pd
from gradio_calendar import Calendar
import datetime
from avatar_generation import get_persona_avatar_bytes
from gen_schedule.generate_character import gen_character
import pickle
from generate_scene import get_person_info_list, person_info_to_description
from utils import CACHE_PERSON_PATH_LIST, MAX_DAYS, SCHEDULE_COLUMNS, create_file_for_character, read_character
from PIL import Image
def get_result(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name):
    """Generate a character's schedule and avatar from the submitted form values.

    Args:
        name, gender, routine, occupation, thoughts, lifestyle: persona fields,
            copied verbatim into the persona dict.
        age: numeric age; converted with ``int()``.
        personality_list: DataFrame with a "Traits" column. Blank or non-text
            cells (e.g. the empty rows the 3-row traits table starts with) are
            skipped, and the remaining traits are stripped of surrounding spaces.
        start_date, end_date: passed to ``gen_character`` as the data span.
        openai_key: OpenAI API key used for both schedule and avatar generation.
        model_name: model identifier forwarded to ``gen_character``.

    Returns:
        The dict produced by ``gen_character``, with ``["persona"]["image"]``
        set to a PIL image decoded from the generated avatar bytes.
    """
    # Empty table rows arrive as "" or NaN; keep only real, non-blank traits.
    traits = [
        trait.strip()
        for trait in personality_list['Traits'].to_list()
        if isinstance(trait, str) and trait.strip()
    ]
    persona = {
        'name': name,
        'age': int(age),
        'gender': gender,
        'routine': routine,
        'personality': traits,
        'occupation': occupation,
        'thoughts': thoughts,
        'lifestyle': lifestyle,
    }
    data_span = (start_date, end_date)
    character_dict = gen_character(persona, data_span, model_name=model_name, openai_api_key=openai_key)
    avatar = get_persona_avatar_bytes(persona, api_key=openai_key)
    character_dict['persona']['image'] = Image.open(BytesIO(avatar))
    return character_dict
def schedule_to_df(schedules, columns: "list[str] | None" = None) -> tuple[list[pd.DataFrame], list[str]]:
    """Convert per-day schedules into one DataFrame per day plus date labels.

    Args:
        schedules: mapping of day -> that day's schedule. Keys must support
            ``strftime`` (e.g. ``datetime.date``). Values are anything
            ``pd.DataFrame`` accepts (presumably a list of row dicts; confirm
            against ``read_character``).
        columns: columns to keep, in order. Defaults to ``SCHEDULE_COLUMNS``.

    Returns:
        ``(dataframes, date_labels)``. ``dataframes[i]`` is the i-th day's
        schedule restricted to ``columns``. ``date_labels[i]`` is that day
        formatted as "YYYY-MM-DD". Both follow the mapping's iteration order.
    """
    selected = SCHEDULE_COLUMNS if columns is None else list(columns)
    schedule_dfs = []
    date_str_list = []
    for date, schedule in schedules.items():
        schedule_dfs.append(pd.DataFrame(schedule)[selected])
        date_str_list.append(date.strftime("%Y-%m-%d"))
    return schedule_dfs, date_str_list
def update_schedules(schedules) -> list:
    """Build the Gradio updates that reveal one schedule tab per generated day.

    Args:
        schedules: mapping of day -> schedule, in the form ``schedule_to_df``
            accepts.

    Returns:
        list: ``MAX_DAYS`` Dataframe updates followed by ``MAX_DAYS`` Tab
        updates. Day ``i`` gets a visible Dataframe holding its schedule and a
        Tab labelled with its date. Unused slots stay hidden.
    """
    import logging

    schedule_df_update = [gr.update(visible=False) for _ in range(MAX_DAYS)]
    schedule_tab_update = [gr.update(visible=False) for _ in range(MAX_DAYS)]
    schedule_df_list, date_str_list = schedule_to_df(schedules)
    days = list(zip(schedule_df_list, date_str_list))
    # Only MAX_DAYS Tab/Dataframe slots exist. Writing past them would raise
    # IndexError, so extra days are dropped (and logged) instead.
    if len(days) > MAX_DAYS:
        logging.getLogger(__name__).warning(
            "Schedule has %d days; only the first %d are displayed.", len(days), MAX_DAYS
        )
        days = days[:MAX_DAYS]
    for i, (df, date_str) in enumerate(days):
        schedule_df_update[i] = gr.update(visible=True, value=df)
        schedule_tab_update[i] = gr.update(visible=True, label=f"{date_str}")
    return schedule_df_update + schedule_tab_update
def change_submit_state():
    """Show the result textbox and hide the Submit button.

    Returns:
        tuple: (update for the result textbox, update for the Submit button).
    """
    show_result = gr.update(visible=True)
    hide_submit = gr.update(visible=False)
    return show_result, hide_submit
def _count_filled_traits(personality_list):
    """Return how many cells of the table's "Traits" column hold non-blank text."""
    if personality_list is None or "Traits" not in personality_list:
        return 0
    return sum(
        1 for trait in personality_list["Traits"]
        if isinstance(trait, str) and trait.strip()
    )


def _validate_submission(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name):
    """Return a user-facing message for the first invalid field, or None if the form is valid."""
    if not name:
        return "Please fill in the name of the person."
    # 0 counts as "not filled in" (the Clear button resets age to 0).
    if not age:
        return "Please fill in the age of the person."
    if gender not in ["Male", "Female"]:
        return "Please select the gender from Male or Female"
    # The traits table always has rows, so count non-blank cells, not rows.
    if _count_filled_traits(personality_list) == 0:
        return "Please fill in the personality traits of the person."
    # A cleared calendar yields None; comparing it with a datetime raises TypeError.
    if start_date is None or end_date is None:
        return "Please select both the start date and the end date."
    # start date must be before end date and the period can not exceed MAX_DAYS days
    if start_date > end_date:
        return "Start date must be before end date."
    if (end_date - start_date).days > MAX_DAYS:
        return f"The period can not exceed {MAX_DAYS} days."
    if not routine:
        return "Please fill in the routine of the person."
    if not occupation:
        return "Please fill in the occupation of the person."
    if not thoughts:
        return "Please fill in the thoughts of the person."
    if not lifestyle:
        return "Please fill in the lifestyle of the person."
    if not openai_key:
        return "Please fill in the OpenAI API Key."
    if model_name not in ["gpt3.5", "gpt4"]:
        return "Please select the model name from gpt3.5 or gpt4."
    return None


def submit_info(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name):
    """Validate the form, generate the character, and build the UI updates.

    Returns:
        list: ``[message, character-file update, "Add Character" button
        update]`` followed by ``MAX_DAYS`` Dataframe updates and ``MAX_DAYS``
        Tab updates. On a validation or generation failure, ``message``
        explains the problem and the file, button, and schedules stay hidden.
    """
    import logging

    add_person_button_visible = gr.update(visible=False)
    file_update = gr.update()
    schedule_updates = [gr.update(visible=False) for _ in range(MAX_DAYS * 2)]
    message = _validate_submission(
        name, age, gender, personality_list, routine, occupation, thoughts,
        lifestyle, start_date, end_date, openai_key, model_name,
    )
    if message is None:
        try:
            character_dict = get_result(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name)
            file_path = create_file_for_character(character_dict)
            _, _, schedules = read_character(character_dict)
            schedule_updates = update_schedules(schedules)
        except Exception as err:
            # UI boundary: report the failure (e.g. a rejected API key) in the
            # result box instead of crashing the event handler.
            logging.getLogger(__name__).exception("Character generation failed")
            message = f"Failed to create the character: {err}"
        else:
            file_update = gr.update(visible=True, value=str(file_path))
            add_person_button_visible = gr.update(visible=True)
            message = "Character created successfully!"
    return [
        message, file_update, add_person_button_visible
    ] + schedule_updates
def reset():
    """Restore every form field and output widget to its initial state.

    The returned tuple lines up positionally with the Clear button's outputs:
    12 form values (including a blank 3-row traits table and "now" for both
    dates), then updates that show Submit and hide the result, file and
    "Add Character" widgets, then one hidden update per schedule tab.
    """
    blank_traits = pd.DataFrame(columns=["Traits"], index=[0, 1, 2])
    form_values = [
        "", 0, "", blank_traits,
        "", "", "", "",
        datetime.datetime.now(), datetime.datetime.now(),
        "", "",
    ]
    widget_updates = [gr.update(visible=True)]
    widget_updates.extend(gr.update(visible=False) for _ in range(3))
    tab_updates = [gr.update(visible=False) for _ in range(MAX_DAYS)]
    return tuple(form_values + widget_updates + tab_updates)
def init_display_person_dataframe():
    """Pre-build MAX_DAYS hidden tabs, each holding an empty schedule Dataframe.

    The tabs start labelled "Day 0", "Day 1", ...; ``update_schedules`` later
    reveals and relabels the ones that hold a schedule.

    Returns:
        tuple[list, list]: (Dataframe components, Tab components), aligned by
        day index.
    """
    dataframes, tabs = [], []
    with gr.Blocks():
        for idx in range(MAX_DAYS):
            with gr.Tab(label=f"Day {idx}", visible=False) as day_tab:
                table = gr.Dataframe()
            dataframes.append(table)
            tabs.append(day_tab)
    return dataframes, tabs
def visualize_person_info():
    """Build the character-creation form inside the enclosing ``gr.Blocks`` context.

    The form collects the persona fields, a date range, and OpenAI settings,
    then wires Submit to ``submit_info`` and Clear to ``reset``.

    Returns:
        tuple: ``(add_person_btn, character_file, person_name_to_file,
        person_name_to_info, person_name_to_description)``. The last three are
        ``gr.State`` objects keyed by person name, built from the cached
        characters in ``CACHE_PERSON_PATH_LIST``.
    """
    # Index the cached characters by name.
    person_file_list: list[str|dict] = CACHE_PERSON_PATH_LIST
    person_info_list = get_person_info_list(person_file_list)
    person_name_to_file = {person_info['name']: person_file for person_info, person_file in zip(person_info_list, person_file_list)}
    person_name_to_info = {person_info['name']: person_info for person_info in person_info_list}
    person_name_to_description = {person_info['name']: person_info_to_description(person_info) for person_info in person_info_list}
    person_name_to_file = gr.State(person_name_to_file)
    person_name_to_info = gr.State(person_name_to_info)
    person_name_to_description = gr.State(person_name_to_description)

    gr.Markdown("### Fill in the basic info of the person you want to create")
    with gr.Row():
        name = gr.Textbox(label="Name")
        age = gr.Number(label="Age", precision=0, minimum=0)
        gender = gr.Dropdown(label="Gender", choices=["Male", "Female"])
    gr.Markdown("### Fill in the personality traits of the person")
    personality_list = gr.Dataframe(col_count=1, row_count=3, label=None, headers=["Traits"])
    gr.Markdown("### Fill in the routine, occupation, thoughts, and lifestyle of the person")
    routine = gr.Textbox(label="Routine", placeholder="e.g. Aarav typically sleeps by 10pm and wakes up at 6am for school.")
    occupation = gr.Textbox(label="Occupation", placeholder="e.g. Aarav is a high school student who is preparing for university entrance exams.")
    thoughts = gr.TextArea(label="Thoughts", placeholder="e.g. Lately, he's been worried about his chemistry exam on February 2nd, 2024, and has been studying extra hours after school to prepare.")
    lifestyle = gr.TextArea(label="Lifestyle", placeholder="e.g. Aarav follows a vegetarian Indian diet and practices yoga every evening. He prefers to study in a quiet room and enjoys taking long walks as a study break. His room is equipped with minimal technology, as he tries to avoid distractions.")
    gr.Markdown("### Select the dates for the person's lifestyle:")
    with gr.Row():
        start_date = Calendar(type="datetime", label="Start Date")
        end_date = Calendar(type="datetime", label="End Date")
        gr.Textbox(f"The period can not exceed {MAX_DAYS} days.", interactive=False, label="Note", show_copy_button=False)

    def _default_end_date(start):
        # Pre-fill the end date with the longest period validation accepts.
        # A cleared calendar yields None (None + timedelta would raise), so
        # leave the end date untouched in that case.
        if start is None:
            return gr.update()
        return start + datetime.timedelta(days=MAX_DAYS)

    start_date.change(
        fn=_default_end_date,
        inputs=[start_date], outputs=[end_date]
    )
    gr.Markdown("### Add openai settings to generate the character")
    openai_key = gr.Textbox(label="OpenAI API Key", placeholder="Enter your OpenAI API Key")
    model_name = gr.Dropdown(label="Model Name", choices=["gpt3.5", "gpt4"])
    submit_btn = gr.Button("Submit")
    result = gr.Textbox(label="Result", lines=2, visible=False, interactive=False)
    schedule_df_list, schedule_tab_list = init_display_person_dataframe()
    character_file = gr.File(label="Download Character", visible=False)
    add_person_btn = gr.Button("Add Character", visible=False)
    clear_btn = gr.Button("Clear")

    form_inputs = [name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name]
    # Run the Submit steps in order: hide the button, generate, then show the
    # button again. `.then` runs even if `submit_info` fails, so after a
    # validation error or a crash the user can fix the form and resubmit
    # instead of having to Clear (and lose) every field.
    submit_btn.click(
        change_submit_state,
        outputs=[result, submit_btn],
    ).then(
        submit_info,
        inputs=form_inputs,
        outputs=[result, character_file, add_person_btn] + schedule_df_list + schedule_tab_list,
    ).then(
        lambda: gr.update(visible=True),
        outputs=[submit_btn],
    )
    clear_btn.click(
        fn=reset,
        outputs=[
            name, age, gender, personality_list, routine,
            occupation, thoughts, lifestyle, start_date, end_date,
            openai_key, model_name, submit_btn, result, character_file, add_person_btn
        ] + schedule_tab_list
    )
    return add_person_btn, character_file, person_name_to_file, person_name_to_info, person_name_to_description
if __name__ == "__main__":
    # Standalone entry point: render only the character-creation form.
    with gr.Blocks() as demo:
        visualize_person_info()
    # Launch after the Blocks context has closed and the layout is finalized.
    demo.launch()
|