Spaces:
Running
Running
from io import BytesIO | |
import gradio as gr | |
import pandas as pd | |
from gradio_calendar import Calendar | |
import datetime | |
from avatar_generation import get_persona_avatar_bytes | |
from gen_schedule.generate_character import gen_character | |
import pickle | |
from generate_scene import get_person_info_list, person_info_to_description | |
from utils import CACHE_PERSON_PATH_LIST, MAX_DAYS, SCHEDULE_COLUMNS, create_file_for_character, read_character | |
from PIL import Image | |
def get_result(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name):
    """Generate a character (schedule data plus avatar image) from the form inputs.

    Args:
        name, gender, routine, occupation, thoughts, lifestyle: persona fields,
            copied into the persona dict unchanged.
        age: converted with ``int`` before use.
        personality_list: DataFrame with a ``Traits`` column, one trait per row.
            Blank or missing cells are dropped.
        start_date, end_date: passed to ``gen_character`` as the
            ``(start_date, end_date)`` span.
        openai_key: used both for ``gen_character`` and for the avatar request.
        model_name: forwarded to ``gen_character``.

    Returns:
        The dict returned by ``gen_character``, with
        ``character_dict['persona']['image']`` set to the generated avatar as a
        PIL image.
    """
    # The traits table has a fixed number of rows, so some cells may be left
    # empty ("" from the UI, NaN after reset()). Skip those so they are not
    # sent to gen_character as traits.
    traits = [
        str(trait).strip()
        for trait in personality_list['Traits'].to_list()
        if not pd.isna(trait) and str(trait).strip()
    ]
    persona = {
        'name': name,
        'age': int(age),
        'gender': gender,
        'routine': routine,
        'personality': traits,
        'occupation': occupation,
        'thoughts': thoughts,
        'lifestyle': lifestyle,
    }
    data_span = (start_date, end_date)
    character_dict = gen_character(persona, data_span, model_name=model_name, openai_api_key=openai_key)
    # get_persona_avatar_bytes returns raw image bytes; decode them into a PIL
    # image so the character file carries a ready-to-use picture.
    avatar = get_persona_avatar_bytes(persona, api_key=openai_key)
    character_dict['persona']['image'] = Image.open(BytesIO(avatar))
    return character_dict
def schedule_to_df(schedules, columns=None) -> tuple[list[pd.DataFrame], list[str]]:
    """Convert per-day schedules into one DataFrame per day.

    Args:
        schedules: mapping of date -> that day's schedule records. Each value
            must be something ``pd.DataFrame`` accepts, and each key must
            provide ``strftime`` (presumably ``datetime.date``/``datetime``
            keys -- confirm against ``read_character``).
        columns: columns to keep, in this order. Defaults to
            ``SCHEDULE_COLUMNS`` when omitted.

    Returns:
        ``(schedule_dfs, date_str_list)``: one DataFrame per day (restricted to
        ``columns``) and the matching ``"YYYY-MM-DD"`` date strings, both in
        the iteration order of ``schedules``.
    """
    if columns is None:
        columns = SCHEDULE_COLUMNS
    schedule_dfs = []
    date_str_list = []
    for date, schedule in schedules.items():
        schedule_dfs.append(pd.DataFrame(schedule)[columns])
        date_str_list.append(date.strftime("%Y-%m-%d"))
    return schedule_dfs, date_str_list
def update_schedules(schedules) -> list:
    """Build Gradio updates for the fixed set of per-day schedule tables and tabs.

    Args:
        schedules: mapping of date -> that day's schedule records, in the form
            accepted by ``schedule_to_df``.

    Returns:
        ``MAX_DAYS`` Dataframe updates followed by ``MAX_DAYS`` Tab updates.
        The first slots are made visible and filled with each day's table and
        date label; the remaining slots are hidden.
    """
    schedule_df_update = [gr.update(visible=False) for _ in range(MAX_DAYS)]
    schedule_tab_update = [gr.update(visible=False) for _ in range(MAX_DAYS)]
    schedule_df_list, date_str_list = schedule_to_df(schedules)
    # Only MAX_DAYS table/tab components exist. If the schedule has more days
    # than that (e.g. an inclusive span of exactly MAX_DAYS), the extra days
    # are not displayed instead of raising IndexError.
    for i, df, date_str in zip(range(MAX_DAYS), schedule_df_list, date_str_list):
        schedule_df_update[i] = gr.update(visible=True, value=df)
        schedule_tab_update[i] = gr.update(visible=True, label=date_str)
    return schedule_df_update + schedule_tab_update
def change_submit_state():
    """Switch the form into its "submitted" state.

    Returns:
        A pair of Gradio updates: the first makes the result box visible, the
        second hides the submit button.
    """
    show_result = gr.update(visible=True)
    hide_submit = gr.update(visible=False)
    return show_result, hide_submit
def submit_info(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name):
    """Validate the form and, if it is valid, generate and display a new character.

    Fields are checked in a fixed order and only the first problem is reported.
    On success, the character is generated via ``get_result``, written with
    ``create_file_for_character``, and its schedules are shown in the per-day
    tabs.

    Returns:
        A flat list matching the submit event's outputs:
        ``[message, character-file update, "Add Character" button update]``
        followed by ``MAX_DAYS`` schedule-table updates and ``MAX_DAYS`` tab
        updates.
    """
    add_person_button_visible = gr.update(visible=False)
    file_update = gr.update()
    message = None

    # The traits table always has a fixed number of rows, so len() of the frame
    # is never 0. Count only cells that actually hold text: blank cells arrive
    # as "" from the UI, or as NaN after a reset.
    if personality_list is None:
        filled_traits = []
    else:
        filled_traits = [
            trait
            for trait in personality_list['Traits'].to_list()
            if not pd.isna(trait) and str(trait).strip()
        ]

    # check the state
    if not name:
        message = "Please fill in the name of the person."
    elif not age:
        message = "Please fill in the age of the person."
    elif gender not in ["Male", "Female"]:
        message = "Please select the gender from Male or Female"
    elif not filled_traits:
        message = "Please fill in the personality traits of the person."
    # Both dates are needed before they can be compared; comparing None raises TypeError.
    elif start_date is None or end_date is None:
        message = "Please select both the start date and the end date."
    # The start date must not be after the end date, and the period cannot exceed MAX_DAYS days.
    # NOTE(review): if gen_character treats the span as inclusive, a gap of
    # exactly MAX_DAYS produces MAX_DAYS + 1 schedule days -- confirm.
    elif start_date > end_date:
        message = "Start date must be before end date."
    elif (end_date - start_date).days > MAX_DAYS:
        message = f"The period can not exceed {MAX_DAYS} days."
    elif not routine:
        message = "Please fill in the routine of the person."
    elif not occupation:
        message = "Please fill in the occupation of the person."
    elif not thoughts:
        message = "Please fill in the thoughts of the person."
    elif not lifestyle:
        message = "Please fill in the lifestyle of the person."
    elif not openai_key:
        message = "Please fill in the OpenAI API Key."
    elif model_name not in ["gpt3.5", "gpt4"]:
        message = "Please select the model name from gpt3.5 or gpt4."

    if message is None:
        character_dict = get_result(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name)
        file_path = create_file_for_character(character_dict)
        file_update = gr.update(visible=True, value=str(file_path))
        add_person_button_visible = gr.update(visible=True)
        message = "Character created successfully!"
        _, _, schedules = read_character(character_dict)
        schedule_updates = update_schedules(schedules)
    else:
        # Validation failed: hide every schedule table and tab.
        schedule_updates = [gr.update(visible=False) for _ in range(MAX_DAYS * 2)]
    return [
        message, file_update, add_person_button_visible
    ] + schedule_updates
def reset():
    """Return every component wired to the Clear button to its initial state.

    The returned tuple is positional and follows the Clear button's
    ``outputs`` list: the twelve form fields, then the submit button, result
    box, character file and "Add Character" button, then ``MAX_DAYS`` schedule
    tabs.
    """
    empty_traits = pd.DataFrame(columns=["Traits"], index=[0, 1, 2])
    form_fields = (
        "",                       # name
        0,                        # age
        "",                       # gender
        empty_traits,             # personality traits table
        "", "", "", "",           # routine, occupation, thoughts, lifestyle
        datetime.datetime.now(),  # start date
        datetime.datetime.now(),  # end date
        "", "",                   # OpenAI key, model name
    )
    control_states = (
        gr.update(visible=True),   # submit button
        gr.update(visible=False),  # result box
        gr.update(visible=False),  # character download file
        gr.update(visible=False),  # "Add Character" button
    )
    hidden_tabs = tuple(gr.update(visible=False) for _ in range(MAX_DAYS))
    return form_fields + control_states + hidden_tabs
def init_display_person_dataframe():
    """Create ``MAX_DAYS`` hidden tabs, each containing an empty schedule table.

    Returns:
        ``(schedule_df_list, schedule_tab_list)``: the Dataframe components and
        their enclosing Tab components, index-aligned by day.
    """
    tables = []
    tabs = []
    with gr.Blocks():
        for day in range(MAX_DAYS):
            # Placeholder label; update_schedules later relabels visible tabs with real dates.
            with gr.Tab(label=f"Day {day}", visible=False) as day_tab:
                table = gr.Dataframe()
            tables.append(table)
            tabs.append(day_tab)
    return tables, tabs
def visualize_person_info():
    """Build the "create a character" form inside the current Gradio Blocks context.

    Lays out the persona inputs, date pickers, OpenAI settings, the per-day
    schedule preview tabs, and the Submit / Add Character / Clear buttons, and
    wires their events. Must be called while a ``gr.Blocks()`` context is open.

    Returns:
        ``(add_person_btn, character_file, person_name_to_file,
        person_name_to_info, person_name_to_description)``: the
        "Add Character" button, the downloadable character-file component, and
        three ``gr.State`` objects. The states are keyed by person name and
        built from ``CACHE_PERSON_PATH_LIST`` via ``get_person_info_list`` and
        ``person_info_to_description``.
    """
    person_file_list: list[str | dict] = CACHE_PERSON_PATH_LIST
    person_info_list = get_person_info_list(person_file_list)
    person_name_to_file = {
        person_info['name']: person_file
        for person_info, person_file in zip(person_info_list, person_file_list)
    }
    person_name_to_info = {person_info['name']: person_info for person_info in person_info_list}
    person_name_to_description = {
        person_info['name']: person_info_to_description(person_info)
        for person_info in person_info_list
    }
    person_name_to_file = gr.State(person_name_to_file)
    person_name_to_info = gr.State(person_name_to_info)
    person_name_to_description = gr.State(person_name_to_description)

    gr.Markdown("### Fill in the basic info of the person you want to create")
    with gr.Row():
        name = gr.Textbox(label="Name")
        age = gr.Number(label="Age", precision=0, minimum=0)
        gender = gr.Dropdown(label="Gender", choices=["Male", "Female"])
    gr.Markdown("### Fill in the personality traits of the person")
    personality_list = gr.Dataframe(col_count=1, row_count=3, label=None, headers=["Traits"])
    gr.Markdown("### Fill in the routine, occupation, thoughts, and lifestyle of the person")
    routine = gr.Textbox(label="Routine", placeholder="e.g. Aarav typically sleeps by 10pm and wakes up at 6am for school.")
    occupation = gr.Textbox(label="Occupation", placeholder="e.g. Aarav is a high school student who is preparing for university entrance exams.")
    thoughts = gr.TextArea(label="Thoughts", placeholder="e.g. Lately, he's been worried about his chemistry exam on February 2nd, 2024, and has been studying extra hours after school to prepare.")
    lifestyle = gr.TextArea(label="Lifestyle", placeholder="e.g. Aarav follows a vegetarian Indian diet and practices yoga every evening. He prefers to study in a quiet room and enjoys taking long walks as a study break. His room is equipped with minimal technology, as he tries to avoid distractions.")
    gr.Markdown("### Select the dates for the person's lifestyle:")
    with gr.Row():
        start_date = Calendar(type="datetime", label="Start Date")
        end_date = Calendar(type="datetime", label="End Date")
        gr.Textbox(f"The period can not exceed {MAX_DAYS} days.", interactive=False, label="Note", show_copy_button=False)

    def _default_end_date(start):
        """Propose an end date MAX_DAYS after the chosen start date."""
        # A cleared calendar may send None; keep the current end date instead
        # of failing on ``None + timedelta``.
        # NOTE(review): if the schedule span is inclusive, start + MAX_DAYS
        # covers MAX_DAYS + 1 days -- confirm against gen_character.
        if start is None:
            return gr.update()
        return start + datetime.timedelta(days=MAX_DAYS)

    start_date.change(
        fn=_default_end_date,
        inputs=[start_date], outputs=[end_date]
    )
    gr.Markdown("### Add openai settings to generate the character")
    openai_key = gr.Textbox(label="OpenAI API Key", placeholder="Enter your OpenAI API Key")
    model_name = gr.Dropdown(label="Model Name", choices=["gpt3.5", "gpt4"])
    submit_btn = gr.Button("Submit")
    result = gr.Textbox(label="Result", lines=2, visible=False, interactive=False)
    schedule_df_list, schedule_tab_list = init_display_person_dataframe()
    character_file = gr.File(label="Download Character", visible=False)
    add_person_btn = gr.Button("Add Character", visible=False)
    clear_btn = gr.Button("Clear")

    # First hide Submit and reveal the result box, then run the (slow)
    # generation. Chaining with .then() guarantees this order; two independent
    # click listeners do not.
    submit_btn.click(
        change_submit_state,
        outputs=[result, submit_btn]
    ).then(
        submit_info,
        inputs=[name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name],
        outputs=[result, character_file, add_person_btn] + schedule_df_list + schedule_tab_list,
    )
    # The order of these outputs must match the tuple returned by reset().
    clear_btn.click(
        fn=reset,
        outputs=[
            name, age, gender, personality_list, routine,
            occupation, thoughts, lifestyle, start_date, end_date,
            openai_key, model_name, submit_btn, result, character_file, add_person_btn
        ] + schedule_tab_list
    )
    return add_person_btn, character_file, person_name_to_file, person_name_to_info, person_name_to_description
if __name__ == "__main__":
    # Standalone preview: host only the character-creation form in its own app.
    # The app object stays named ``demo``.
    demo = gr.Blocks()
    with demo:
        visualize_person_info()
    demo.launch()
# data_path = 'static/exist_characters/Aarav_Patil/Aarav_Patil.pkl' | |
# import pickle | |
# data = pickle.load(open(data_path, 'rb')) | |
# persona = data['persona'] | |
# dates = list(data['schedule']['default'].keys()) | |
# start_date = dates[0] | |
# end_date = dates[-1] | |
# character_dict = gen_character(persona, (start_date, start_date), model_name="gpt3", openai_api_key="") | |
# # from IPython import embed; embed() | |