Spaces:
Running
Running
| from io import BytesIO | |
| import gradio as gr | |
| import pandas as pd | |
| from gradio_calendar import Calendar | |
| import datetime | |
| from avatar_generation import get_persona_avatar_bytes | |
| from gen_schedule.generate_character import gen_character | |
| import pickle | |
| from generate_scene import get_person_info_list, person_info_to_description | |
| from utils import CACHE_PERSON_PATH_LIST, MAX_DAYS, SCHEDULE_COLUMNS, create_file_for_character, read_character | |
| from PIL import Image | |
def _clean_traits(personality_df) -> list:
    """Return the non-blank trait strings from the ``Traits`` column.

    The personality table has a fixed number of rows, and ``reset`` fills it
    with NaN. Unused rows therefore show up as NaN or whitespace-only cells,
    and they are dropped here so they never reach the generated persona.
    """
    stripped = (str(trait).strip() for trait in personality_df["Traits"].dropna())
    return [trait for trait in stripped if trait]


def get_result(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name):
    """Generate a character (schedule and avatar) from the submitted form values.

    Args:
        name, gender, routine, occupation, thoughts, lifestyle: persona fields,
            stored as-is.
        age: numeric age; converted with ``int``.
        personality_list: pandas DataFrame with a ``Traits`` column. Blank
            rows are ignored.
        start_date, end_date: bounds of the span passed to ``gen_character``.
        openai_key: OpenAI API key used for both generation and the avatar.
        model_name: model identifier forwarded to ``gen_character``.

    Returns:
        The dict produced by ``gen_character``, with
        ``['persona']['image']`` set to the generated avatar as a PIL image.
    """
    persona = {
        'name': name,
        'age': int(age),
        'gender': gender,
        'routine': routine,
        'personality': _clean_traits(personality_list),
        'occupation': occupation,
        'thoughts': thoughts,
        'lifestyle': lifestyle,
    }
    data_span = (start_date, end_date)
    character_dict = gen_character(persona, data_span, model_name=model_name, openai_api_key=openai_key)
    avatar = get_persona_avatar_bytes(persona, api_key=openai_key)
    character_dict['persona']['image'] = Image.open(BytesIO(avatar))
    return character_dict
def schedule_to_df(schedules, columns=None) -> tuple[list[pd.DataFrame], list[str]]:
    """Convert per-day schedules into DataFrames plus their date labels.

    Args:
        schedules: mapping of date -> that day's schedule. Each schedule is
            passed to ``pd.DataFrame`` and must contain every selected column.
            Keys must support ``strftime``. Entries are processed in the
            mapping's iteration order.
        columns: columns to keep, in display order. Defaults to
            ``SCHEDULE_COLUMNS``.

    Returns:
        ``(schedule_dfs, date_strs)``: one DataFrame per day and the matching
        ``"YYYY-MM-DD"`` label, in the same order.
    """
    if columns is None:
        columns = SCHEDULE_COLUMNS
    schedule_dfs = []
    date_strs = []
    for date, schedule in schedules.items():
        schedule_dfs.append(pd.DataFrame(schedule)[columns])
        date_strs.append(date.strftime("%Y-%m-%d"))
    return schedule_dfs, date_strs
def update_schedules(schedules) -> list:
    """Build value/visibility updates for the fixed pool of schedule slots.

    Args:
        schedules: mapping of date -> day schedule, in the form accepted by
            ``schedule_to_df``.

    Returns:
        ``MAX_DAYS`` Dataframe updates followed by ``MAX_DAYS`` tab updates.
        Slot ``i`` shows day ``i``; unused slots are hidden.
    """
    df_updates = [gr.update(visible=False) for _ in range(MAX_DAYS)]
    tab_updates = [gr.update(visible=False) for _ in range(MAX_DAYS)]
    schedule_dfs, date_strs = schedule_to_df(schedules)
    # Only MAX_DAYS slots exist in the UI. Any extra days are not displayed,
    # rather than causing an IndexError on the update lists.
    visible_days = zip(schedule_dfs[:MAX_DAYS], date_strs[:MAX_DAYS])
    for i, (df, date_str) in enumerate(visible_days):
        df_updates[i] = gr.update(visible=True, value=df)
        tab_updates[i] = gr.update(visible=True, label=date_str)
    return df_updates + tab_updates
def change_submit_state():
    """Reveal the result box and hide the submit button when Submit is clicked.

    Returns:
        ``(result update, submit-button update)``, in the order of the click
        outputs.
    """
    show_result = gr.update(visible=True)
    hide_submit = gr.update(visible=False)
    return show_result, hide_submit
def _count_filled_traits(personality_df) -> int:
    """Count non-blank entries in the ``Traits`` column (0 if the table or column is missing)."""
    if personality_df is None or "Traits" not in personality_df:
        return 0
    return sum(1 for trait in personality_df["Traits"].dropna() if str(trait).strip())


def _validate_inputs(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name, max_days):
    """Return a user-facing message for the first invalid field, or None if everything is valid."""
    if not name:
        return "Please fill in the name of the person."
    if not age:
        return "Please fill in the age of the person."
    if gender not in ["Male", "Female"]:
        return "Please select the gender from Male or Female"
    # The traits table always has rows, so count real entries instead of rows.
    if _count_filled_traits(personality_list) == 0:
        return "Please fill in the personality traits of the person."
    # Guard before comparing: ordering None with a datetime raises TypeError.
    if start_date is None or end_date is None:
        return "Please select both the start date and the end date."
    # The start date must not be after the end date, and the period cannot
    # exceed max_days days.
    if start_date > end_date:
        return "Start date must be before end date."
    if (end_date - start_date).days > max_days:
        return f"The period can not exceed {max_days} days."
    if not routine:
        return "Please fill in the routine of the person."
    if not occupation:
        return "Please fill in the occupation of the person."
    if not thoughts:
        return "Please fill in the thoughts of the person."
    if not lifestyle:
        return "Please fill in the lifestyle of the person."
    if not openai_key:
        return "Please fill in the OpenAI API Key."
    if model_name not in ["gpt3.5", "gpt4"]:
        return "Please select the model name from gpt3.5 or gpt4."
    return None


def submit_info(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name):
    """Validate the form, generate the character, and build the UI updates.

    Returns:
        ``[message, character_file update, add_person_btn update]``, followed
        by ``MAX_DAYS`` schedule-Dataframe updates and ``MAX_DAYS``
        schedule-tab updates. This matches the outputs wired to the submit
        button. On any failure the file is left unchanged, and both the
        add-person button and the schedule slots are hidden.
    """
    hidden_schedules = [gr.update(visible=False) for _ in range(MAX_DAYS * 2)]
    failure_tail = [gr.update(), gr.update(visible=False)] + hidden_schedules

    message = _validate_inputs(
        name, age, gender, personality_list, routine, occupation, thoughts,
        lifestyle, start_date, end_date, openai_key, model_name, max_days=MAX_DAYS,
    )
    if message is not None:
        return [message] + failure_tail

    try:
        character_dict = get_result(name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name)
        file_path = create_file_for_character(character_dict)
        _, _, schedules = read_character(character_dict)
        schedule_updates = update_schedules(schedules)
    except Exception as err:
        # UI boundary: show generation/API failures in the result box instead
        # of letting the event handler crash with a generic error.
        return [f"Failed to create character: {err}"] + failure_tail

    return [
        "Character created successfully!",
        gr.update(visible=True, value=str(file_path)),
        gr.update(visible=True),
    ] + schedule_updates
def reset():
    """Return the values that restore the creation form to its initial state.

    The order matches the outputs wired to the Clear button:
      1. the twelve form inputs;
      2. the visibility of the submit button, result box, download file, and
         add-character button;
      3. one update per schedule tab. The Dataframes inside the tabs are not
         touched.
    """
    blank_traits = pd.DataFrame(columns=["Traits"], index=[0, 1, 2])
    form_values = (
        "",                       # name
        0,                        # age
        "",                       # gender
        blank_traits,             # personality traits
        "",                       # routine
        "",                       # occupation
        "",                       # thoughts
        "",                       # lifestyle
        datetime.datetime.now(),  # start date
        datetime.datetime.now(),  # end date
        "",                       # OpenAI key
        "",                       # model name
    )
    visibility = (
        gr.update(visible=True),   # submit button
        gr.update(visible=False),  # result box
        gr.update(visible=False),  # character file
        gr.update(visible=False),  # add-character button
    )
    hidden_tabs = tuple(gr.update(visible=False) for _ in range(MAX_DAYS))
    return form_values + visibility + hidden_tabs
def init_display_person_dataframe():
    """Pre-create one hidden tab per possible day, each holding an empty Dataframe.

    The slots are created up front, all hidden. ``update_schedules``
    produces the updates that later show and fill them.

    Returns:
        ``(schedule_df_list, schedule_tab_list)``: aligned lists of length
        ``MAX_DAYS``.
    """
    frames = []
    tabs = []
    with gr.Blocks():
        for index in range(MAX_DAYS):
            with gr.Tab(label=f"Day {index}", visible=False) as day_tab:
                frames.append(gr.Dataframe())
            tabs.append(day_tab)
    return frames, tabs
def visualize_person_info():
    """Build the character-creation UI inside the currently open ``gr.Blocks``.

    Components are created in page order, so this must be called within a
    ``gr.Blocks`` context.

    Returns:
        ``(add_person_btn, character_file, person_name_to_file,
        person_name_to_info, person_name_to_description)``. The first two are
        the "Add Character" button and the download File component. The last
        three are ``gr.State`` objects keyed by person name, built from
        ``CACHE_PERSON_PATH_LIST``.
    """
    # Index the cached people by name for the caller.
    person_file_list: list[str | dict] = CACHE_PERSON_PATH_LIST
    person_info_list = get_person_info_list(person_file_list)
    person_name_to_file = {person_info['name']: person_file for person_info, person_file in zip(person_info_list, person_file_list)}
    person_name_to_info = {person_info['name']: person_info for person_info in person_info_list}
    person_name_to_description = {person_info['name']: person_info_to_description(person_info) for person_info in person_info_list}
    person_name_to_file = gr.State(person_name_to_file)
    person_name_to_info = gr.State(person_name_to_info)
    person_name_to_description = gr.State(person_name_to_description)

    gr.Markdown("### Fill in the basic info of the person you want to create")
    with gr.Row():
        name = gr.Textbox(label="Name")
        age = gr.Number(label="Age", precision=0, minimum=0)
        gender = gr.Dropdown(label="Gender", choices=["Male", "Female"])

    gr.Markdown("### Fill in the personality traits of the person")
    personality_list = gr.Dataframe(col_count=1, row_count=3, label=None, headers=["Traits"])

    gr.Markdown("### Fill in the routine, occupation, thoughts, and lifestyle of the person")
    routine = gr.Textbox(label="Routine", placeholder="e.g. Aarav typically sleeps by 10pm and wakes up at 6am for school.")
    occupation = gr.Textbox(label="Occupation", placeholder="e.g. Aarav is a high school student who is preparing for university entrance exams.")
    thoughts = gr.TextArea(label="Thoughts", placeholder="e.g. Lately, he's been worried about his chemistry exam on February 2nd, 2024, and has been studying extra hours after school to prepare.")
    lifestyle = gr.TextArea(label="Lifestyle", placeholder="e.g. Aarav follows a vegetarian Indian diet and practices yoga every evening. He prefers to study in a quiet room and enjoys taking long walks as a study break. His room is equipped with minimal technology, as he tries to avoid distractions.")

    gr.Markdown("### Select the dates for the person's lifestyle:")
    with gr.Row():
        start_date = Calendar(type="datetime", label="Start Date")
        end_date = Calendar(type="datetime", label="End Date")
        gr.Textbox(f"The period can not exceed {MAX_DAYS} days.", interactive=False, label="Note", show_copy_button=False)

    def _sync_end_date(start):
        # Default the end date to the longest span submit_info accepts. If the
        # start date is empty (None), leave the end date alone instead of
        # raising TypeError on None + timedelta.
        if start is None:
            return gr.update()
        return start + datetime.timedelta(days=MAX_DAYS)

    start_date.change(fn=_sync_end_date, inputs=[start_date], outputs=[end_date])

    gr.Markdown("### Add openai settings to generate the character")
    openai_key = gr.Textbox(label="OpenAI API Key", placeholder="Enter your OpenAI API Key")
    model_name = gr.Dropdown(label="Model Name", choices=["gpt3.5", "gpt4"])

    submit_btn = gr.Button("Submit")
    result = gr.Textbox(label="Result", lines=2, visible=False, interactive=False)
    schedule_df_list, schedule_tab_list = init_display_person_dataframe()
    character_file = gr.File(label="Download Character", visible=False)
    add_person_btn = gr.Button("Add Character", visible=False)
    clear_btn = gr.Button("Clear")

    submit_btn.click(
        change_submit_state,
        outputs=[result, submit_btn],
    )
    submit_btn.click(
        submit_info,
        inputs=[name, age, gender, personality_list, routine, occupation, thoughts, lifestyle, start_date, end_date, openai_key, model_name],
        outputs=[result, character_file, add_person_btn] + schedule_df_list + schedule_tab_list,
    )
    clear_btn.click(
        fn=reset,
        outputs=[
            name, age, gender, personality_list, routine,
            occupation, thoughts, lifestyle, start_date, end_date,
            openai_key, model_name, submit_btn, result, character_file, add_person_btn,
        ] + schedule_tab_list,
    )
    return add_person_btn, character_file, person_name_to_file, person_name_to_info, person_name_to_description
if __name__ == "__main__":
    # Run the character-creation page as a standalone Gradio app.
    with gr.Blocks() as demo:
        visualize_person_info()
    demo.launch()