Spaces:
Running
Running
import json | |
import gradio as gr | |
from pathlib import Path | |
import os | |
import pickle | |
from constants import OBJECTS, RECEPTACLES | |
import pandas as pd | |
from generate_video import generate_video | |
from utils import * | |
def _build_type_count_table(type_names, type_column, row_number=10):
    """Return a fixed-height table of the most frequent entries of *type_names*.

    Args:
        type_names: Iterable of type names, one entry per object instance.
        type_column: Header of the first column of the resulting table.
        row_number: Exact number of rows in the table. Less frequent types
            beyond this limit are dropped; missing rows are padded with
            ``('', '')`` so the table always has the same height.

    Returns:
        A ``pandas.DataFrame`` with columns ``[type_column, 'Count']`` sorted
        by descending count (ties keep first-seen order).
    """
    from collections import Counter  # local import: keeps the block self-contained

    rows = Counter(type_names).most_common(row_number)
    rows += [('', '')] * (row_number - len(rows))
    return pd.DataFrame(rows, columns=[type_column, 'Count'])


def get_scene_info(scene_file_name: str) -> tuple[str, str, pd.DataFrame, pd.DataFrame]:
    """Collect the overhead image, summary text and type-count tables of a scene.

    Args:
        scene_file_name: Scene identifier, resolved through
            ``get_scene_dir_path`` and ``get_scene_image_path``.

    Returns:
        A 4-tuple ``(scene_image_path, markdown_description,
        receptacle_counter_df, object_type_counter_df)``:

        * scene_image_path: whatever ``get_scene_image_path`` returns
          (presumably a path to the scene image -- confirm in ``utils``).
        * markdown_description: bullet list with the room total, the rooms
          per type, and the number of objects and receptacles.
        * receptacle_counter_df: top-10 receptacle types with their counts.
        * object_type_counter_df: top-10 non-receptacle object types with
          their counts.
    """
    scene_dir = get_scene_dir_path(scene_file_name)
    scene_image_path = get_scene_image_path(scene_file_name)
    scene_data = get_scene_data(scene_dir)

    # The object type is the part of the id that precedes the first '|'.
    all_object_types = [
        object_info['id'].split('|', 1)[0] for object_info in scene_data['objects']
    ]
    room_counter = get_room_counter(scene_data)

    # Hoisted: the difference is loop-invariant.
    non_receptacle_types = OBJECTS - RECEPTACLES
    receptacle_types = [t for t in all_object_types if t in RECEPTACLES]
    object_types = [t for t in all_object_types if t in non_receptacle_types]

    receptacle_counter_df = _build_type_count_table(receptacle_types, 'Receptacle Type')
    object_type_counter_df = _build_type_count_table(object_types, 'Object Type')

    # Room types render like "4 Bedroom, 1 Living Room, 1 Kitchen, 1 Bathroom".
    room_types = ', '.join(
        f'{count} {room_type}' for room_type, count in room_counter.items()
    )
    markdown_description = (
        "Scene Information\n"
        f"- Number of Rooms: {sum(room_counter.values())}\n"
        f"- Room Types: {room_types}\n"
        f"- Number of Objects: {len(object_types)}\n"
        f"- Number of Receptacles: {len(receptacle_types)}\n"
    )
    return scene_image_path, markdown_description, receptacle_counter_df, object_type_counter_df
def visualize_scene():
    """Build the scene-selection widgets and return the scene dropdown.

    The widgets are initialised from the first entry of
    ``SCENE_FILE_NAME_LIST``. Changing the dropdown re-runs
    ``get_scene_info`` and refreshes the image, the statistics text box and
    both count tables.

    Returns:
        The ``gr.Dropdown`` used to pick the scene.
    """
    # Use Grid layout components to organize the interface
    # NOTE(review): the nesting of the layout contexts below is reconstructed;
    # confirm it against the intended page layout.
    default_scene_file_name = SCENE_FILE_NAME_LIST[0]
    default_scene_image_path, default_text, default_receptacle_table, default_object_table = get_scene_info(default_scene_file_name)
    with gr.Row():
        dropdown = gr.Dropdown(choices=SCENE_FILE_NAME_LIST, label="Select Scene ID", value=default_scene_file_name)
    with gr.Row(equal_height=True):
        image = gr.Image(label="Scene Overhead View", show_label=False, value=default_scene_image_path)
        with gr.Column():
            text = gr.Textbox(label="Scene Statistics", value=default_text, lines=6)
            with gr.Row(equal_height=True):
                receptacle_table = gr.Dataframe(label="Receptacle Type Count", height=520, value=default_receptacle_table)
                object_table = gr.Dataframe(label="Object Type Count", height=520, value=default_object_table)
    # The output order must match the 4-tuple returned by get_scene_info.
    dropdown.change(fn=get_scene_info, inputs=dropdown, outputs=[image, text, receptacle_table, object_table])
    return dropdown
# Columns of the compact people table, in display order.
PERSON_TABLE_COLUMNS = [
    'name',
    'age',
    'gender',
    'personality',
    'routine',
    'occupation',
    'thoughts',
    'lifestyle',
]
# Number of person rows that are pre-built (hidden) in the full view.
TEMPLATE_ROW_NUMBER = 10
# Number of components per person row: image, name, age, gender,
# personality and routine.
PERSON_ELEMENT_NUM = 6
def get_person_info_list(person_file_list: list[str|Path|dict]): | |
person_info_list = [] | |
for person_file in person_file_list: | |
if isinstance(person_file, str) or isinstance(person_file, Path): | |
with open(person_file, 'rb') as f: | |
person_info = pickle.load(f) | |
person_info = person_info['persona'] | |
if 'image' not in person_info: | |
image_path = os.path.join(os.path.dirname(person_file), 'avatar.jpg') | |
person_info['image'] = image_path | |
else: | |
person_info = person_file['persona'] | |
person_info_list.append(person_info) | |
return person_info_list | |
def person_info_to_description(person_info):
    """Summarise a persona as one sentence: name, traits, age and gender noun.

    Any ``'gender'`` value other than ``'Male'`` is rendered as ``'woman'``.
    """
    name = person_info['name']
    traits = ', '.join(person_info['personality'])
    age = person_info['age']
    if person_info['gender'] == 'Male':
        noun = 'man'
    else:
        noun = 'woman'
    return f"{name}, a {traits} {age} years old {noun}."
def person_info_to_elements(person_info):
    """Return the values of one person row as a 6-tuple.

    The tuple holds the image followed by Markdown strings of the form
    ``**Label:** value`` for name, age, gender, personality (comma-joined)
    and routine, in that order.
    """
    image = person_info['image']
    labelled_values = (
        ('Name', person_info['name']),
        ('Age', person_info['age']),
        ('Gender', person_info['gender']),
        ('Personality', ', '.join(person_info['personality'])),
        ('Routine', person_info['routine']),
    )
    return (image, *(f"**{label}:** {value}" for label, value in labelled_values))
def get_person_elements_from_row_elements(elements: list[str], row_index: int):
    """Return the slice of the flat *elements* list that belongs to *row_index*.

    Every row occupies ``PERSON_ELEMENT_NUM`` consecutive entries.
    """
    first = row_index * PERSON_ELEMENT_NUM
    last = first + PERSON_ELEMENT_NUM
    return elements[first:last]
def get_person_name_from_row_elements(row_elements: list[str]):
    """Extract the person name from the values of one person row.

    Args:
        row_elements: Values of one row; index 1 holds the name Markdown,
            formatted as ``**Name:** <name>``.

    Returns:
        The name without the leading ``**Name:** `` label.
    """
    # Strip only the leading label: str.replace would also remove the same
    # text if it happened to occur inside the name itself.
    return row_elements[1].removeprefix('**Name:** ')
def get_person_dataframe(row_index_to_person_name: list[str], person_name_to_info: dict[str, dict]):
    """Build the compact people table for the persons currently in the scene.

    Returns an empty frame with the ``PERSON_TABLE_COLUMNS`` header when no
    person is listed; otherwise one row per listed name, restricted to
    ``PERSON_TABLE_COLUMNS``.
    """
    if len(row_index_to_person_name) == 0:
        return pd.DataFrame(columns=PERSON_TABLE_COLUMNS)
    person_records = [person_name_to_info[person_name] for person_name in row_index_to_person_name]
    full_table = pd.DataFrame(person_records)
    return full_table[PERSON_TABLE_COLUMNS]
def get_max_person_number(scene_dir_path: Path) -> int:
    """Return the ``'Bedroom'`` entry of the scene's room counter (0 if absent)."""
    rooms_by_type = get_room_counter(get_scene_data(scene_dir_path))
    return rooms_by_type.get('Bedroom', 0)
def create_person_page():
    """Create one initially hidden person row and return its components.

    Returns:
        A dict with the keys ``'row'`` (the enclosing ``gr.Row``),
        ``'delete_button'`` and ``'elements'`` -- the tuple
        ``(image, name, age, gender, personality, routine)`` whose order
        matches the tuple produced by ``person_info_to_elements``.
    """
    # NOTE(review): the nesting of the layout contexts below is reconstructed;
    # confirm it against the intended page layout.
    with gr.Row(visible=False) as row:
        image = gr.Image(width=200, scale=0.25, show_label=False, interactive=False)
        with gr.Column():
            with gr.Row():
                name = gr.Markdown()
                delete_button = gr.Button("Delete",size='sm')
            age = gr.Markdown()
            gender = gr.Markdown()
            personality = gr.Markdown()
            routine = gr.Markdown()
            # gr.Markdown(f"**Occupation:** {person['occupation']}")
            # gr.Markdown(f"**Thoughts:** {person['thoughts']}")
            # gr.Markdown(f"**Lifestyle:** {person['lifestyle']}")
    return {
        'row': row,
        'delete_button': delete_button,
        'elements': (image, name, age, gender, personality, routine),
    }
def full_view(person_name_to_description: dict[str, str]):
    """Build the full view: the template person rows plus the add controls.

    Args:
        person_name_to_description: Maps a person name to the description
            shown in the 'Add Person' dropdown; choices are listed in
            sorted-name order.

    Returns:
        ``(person_pages, add_person_dropdown, add_button)`` where
        ``person_pages`` holds ``TEMPLATE_ROW_NUMBER`` dicts as returned by
        ``create_person_page``.
    """
    # NOTE(review): the extent of the gr.Blocks context is reconstructed;
    # confirm which components belong inside it.
    with gr.Blocks():
        person_pages = [create_person_page() for _ in range(TEMPLATE_ROW_NUMBER)]
        add_person_dropdown = gr.Dropdown(choices=[person_name_to_description[key] for key in sorted(person_name_to_description.keys())], label='Add Person')
        add_button = gr.Button('Add')
    return (person_pages, add_person_dropdown, add_button)
def compact_view():
    """Create the compact people table, initially empty with only the header row."""
    empty_table = pd.DataFrame(columns=PERSON_TABLE_COLUMNS)
    return gr.Dataframe(value=empty_table, label='People Information')
def add_person_to_scene(
    max_person_number: int,
    person_dropdown_description: str, person_name_to_description: dict[str, str],
    row_index_to_person_name: list[str],
    person_name_to_info: dict[str, dict]
):
    """Show the person selected in the dropdown in the next free person row.

    Args:
        max_person_number: Maximum number of persons allowed in the scene.
        person_dropdown_description: Description currently selected in the
            'Add Person' dropdown.
        person_name_to_description: Maps person name to dropdown description.
        row_index_to_person_name: Names of the persons already shown, indexed
            by row. The list is not modified; an extended copy is returned.
        person_name_to_info: Maps person name to its persona mapping.

    Returns:
        A flat list, in this order: one update per row element
        (``TEMPLATE_ROW_NUMBER * PERSON_ELEMENT_NUM`` entries), one
        visibility update per row (``TEMPLATE_ROW_NUMBER`` entries), the new
        row-to-name list, the dropdown update, the add-button update, the
        compact people table and the generate-button update.

        When nothing valid can be added (no matching selection, person
        already shown, or no template row left) every component is left
        untouched and the row-to-name list is returned unchanged.
    """
    row_elements_update_list = [gr.update() for _ in range(TEMPLATE_ROW_NUMBER * PERSON_ELEMENT_NUM)]
    row_visible_list = [gr.update() for _ in range(TEMPLATE_ROW_NUMBER)]

    # Reverse lookup description -> name; None when nothing (valid) is selected.
    add_person_name = next(
        (
            name
            for name, description in person_name_to_description.items()
            if description == person_dropdown_description
        ),
        None,
    )
    new_row_index = len(row_index_to_person_name)
    if (
        add_person_name is None
        or add_person_name in row_index_to_person_name
        or new_row_index >= TEMPLATE_ROW_NUMBER
    ):
        # No-op: previously this raised IndexError (empty selection or no
        # template row left) or showed the same person twice.
        return row_elements_update_list + row_visible_list + [
            row_index_to_person_name,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        ]

    current_row_number = new_row_index + 1
    # Only TEMPLATE_ROW_NUMBER rows exist, so never offer more than that.
    add_button_visible = current_row_number < min(max_person_number, TEMPLATE_ROW_NUMBER)

    add_person_elements = person_info_to_elements(person_name_to_info[add_person_name])
    first_element = new_row_index * PERSON_ELEMENT_NUM
    row_elements_update_list[first_element: first_element + PERSON_ELEMENT_NUM] = add_person_elements
    row_visible_list[new_row_index] = gr.update(visible=True)

    # Build a new list instead of mutating the incoming state value.
    updated_row_names = [*row_index_to_person_name, add_person_name]
    remaining_names = sorted(
        name for name in person_name_to_description if name not in updated_row_names
    )
    remaining_descriptions = [person_name_to_description[name] for name in remaining_names]

    return row_elements_update_list + row_visible_list + [
        updated_row_names,
        gr.update(choices=remaining_descriptions, visible=add_button_visible),
        gr.update(visible=add_button_visible),
        get_person_dataframe(updated_row_names, person_name_to_info),
        gr.update(visible=True)
    ]
def delete_person_from_scene(
    person_name_to_info: dict[str, dict],
    row_index_to_person_name: list[str],
    button_row_index: int,
    person_name_to_description: dict[str, str],
    *row_elements: object
):
    """Remove the person shown in row *button_row_index* and close the gap.

    The element values of the rows below the deleted one are shifted up by
    one row and the last occupied row is hidden.

    Args:
        person_name_to_info: Maps person name to its persona mapping.
        row_index_to_person_name: Names of the persons currently shown,
            indexed by row. The list is not modified; a shortened copy is
            returned.
        button_row_index: Index of the row whose delete button was clicked.
        person_name_to_description: Maps person name to dropdown description.
        *row_elements: Current values of all row elements, flattened row by
            row (``PERSON_ELEMENT_NUM`` values per row).

    Returns:
        A flat list, in this order: the row element values, one visibility
        update per row (``TEMPLATE_ROW_NUMBER`` entries), the add-button
        update, the dropdown update, the new row-to-name list, the compact
        people table and the generate-button update.

        When *button_row_index* does not address an occupied row, every
        component is left untouched and the list is returned unchanged.
    """
    current_row_number = len(row_index_to_person_name)
    row_elements = list(row_elements)
    person_row_visible = [gr.update() for _ in range(TEMPLATE_ROW_NUMBER)]

    if not 0 <= button_row_index < current_row_number:
        # No-op: previously this raised IndexError from list.pop.
        return row_elements + person_row_visible + [
            gr.update(),
            gr.update(),
            row_index_to_person_name,
            gr.update(),
            gr.update(),
        ]

    # Shift the values of all rows below the deleted one up by one row. The
    # values of the last occupied row stay as they are; that row is hidden.
    first_element = button_row_index * PERSON_ELEMENT_NUM
    end_element = current_row_number * PERSON_ELEMENT_NUM
    row_elements[first_element: end_element - PERSON_ELEMENT_NUM] = row_elements[first_element + PERSON_ELEMENT_NUM: end_element]
    person_row_visible[current_row_number - 1] = gr.update(visible=False)

    # Build a new list instead of mutating the incoming state value.
    remaining_row_names = (
        row_index_to_person_name[:button_row_index]
        + row_index_to_person_name[button_row_index + 1:]
    )
    selectable_names = sorted(
        name for name in person_name_to_description if name not in remaining_row_names
    )
    selectable_descriptions = [person_name_to_description[name] for name in selectable_names]

    return row_elements + person_row_visible + [
        gr.update(visible=True),
        gr.update(choices=selectable_descriptions, visible=True),
        remaining_row_names,
        get_person_dataframe(remaining_row_names, person_name_to_info),
        gr.update(visible=len(remaining_row_names) > 0)
    ]
def reset_person_rows(person_name_to_description: dict[str, str]):
    """Return the updates that put the person section back into its empty state.

    The result holds one hide-update per template row, followed by an empty
    row-to-name list, a show-update for the add button, a dropdown update
    offering every description (sorted by person name) and an empty people
    table.
    """
    hidden_rows = [gr.update(visible=False) for _ in range(TEMPLATE_ROW_NUMBER)]
    all_descriptions = [
        person_name_to_description[name]
        for name in sorted(person_name_to_description.keys())
    ]
    remaining_outputs = [
        [],
        gr.update(visible=True),
        gr.update(choices=all_descriptions, visible=True),
        pd.DataFrame(columns=PERSON_TABLE_COLUMNS),
    ]
    return hidden_rows + remaining_outputs
def reset_person_page(
    scene_file_name: str,
    person_name_to_description: dict[str, str],
):
    """Reset the person section for the given scene.

    Returns the outputs of ``reset_person_rows`` followed by the maximum
    person number of the scene and two hide-updates.
    """
    # The person limit is the bedroom count of the scene.
    max_added_person = get_max_person_number(SCENE_ROOT_DIR / scene_file_name)
    person_outputs = reset_person_rows(person_name_to_description)
    trailing_outputs = [max_added_person, gr.update(visible=False), gr.update(visible=False)]
    return person_outputs + trailing_outputs
def visualize_person(person_name_to_info: gr.State, person_name_to_description: gr.State, max_added_person: int):
    """Build the person section (full and compact view) and wire its buttons.

    Args:
        person_name_to_info: State passed as input to the add/delete handlers.
        person_name_to_description: State whose ``.value`` seeds the dropdown
            choices and which is passed as input to the add/delete handlers.
        max_added_person: Initial person limit; wrapped in a ``gr.State``.

    Returns:
        ``(row_index_to_person_name, max_added_person, row_elements, rows,
        delete_buttons, add_button, add_person_dropdown, person_dataframe,
        generate_button)``.
    """
    row_index_to_person_name = gr.State([])
    # From here on the name refers to the State wrapper, not the int argument.
    max_added_person = gr.State(max_added_person)
    # NOTE(review): the nesting of the layout contexts below is reconstructed;
    # confirm it against the intended page layout.
    with gr.Blocks():
        with gr.Tab(label='Full View'):
            person_pages, add_person_dropdown, add_button = full_view(person_name_to_description.value)
        with gr.Tab(label='Compact View'):
            person_dataframe = compact_view()
        generate_button = gr.Button("Generate Video", visible=False)
    # Flatten the per-row element tuples row by row; the handlers index this
    # flat list in steps of PERSON_ELEMENT_NUM.
    row_elements = [element for page in person_pages for element in page['elements']]
    rows = [page['row'] for page in person_pages]
    delete_buttons: list[gr.Button] = [page['delete_button'] for page in person_pages]
    # The output order must match the list returned by add_person_to_scene.
    add_button.click(
        fn=add_person_to_scene,
        inputs=[max_added_person, add_person_dropdown, person_name_to_description, row_index_to_person_name, person_name_to_info],
        outputs=row_elements + rows + [row_index_to_person_name, add_person_dropdown, add_button, person_dataframe, generate_button]
    )
    for i, delete_button in enumerate(delete_buttons):
        # gr.State(i) hands the row index of the clicked button to the handler.
        # The output order must match the list returned by delete_person_from_scene.
        delete_button.click(
            fn=delete_person_from_scene,
            inputs=[person_name_to_info, row_index_to_person_name, gr.State(i), person_name_to_description] + row_elements,
            outputs=row_elements + rows + [add_button, add_person_dropdown, row_index_to_person_name, person_dataframe, generate_button]
        )
    return row_index_to_person_name, max_added_person, row_elements, rows, delete_buttons, add_button, add_person_dropdown, person_dataframe, generate_button
def generate_button_click_change_state():
    """Return two updates: hide the first output, show the second one."""
    hide_first = gr.update(visible=False)
    show_second = gr.update(visible=True)
    return hide_first, show_second
def generate_button_click_change_video(
    scene_file_name: str,
    row_index_to_person_name: list[str],
    person_name_to_file: dict[str, object]
):
    """Generate the video for the scene and the persons currently shown.

    Args:
        scene_file_name: Scene identifier forwarded to ``generate_video``.
        row_index_to_person_name: Names of the persons shown, in row order.
        person_name_to_file: Maps person name to the per-person value that
            is forwarded to ``generate_video`` (presumably a person file or
            loaded person data -- confirm against ``generate_video``).

    Returns:
        ``str(video_path)`` when ``generate_video`` returns a truthy path,
        otherwise an update that hides the output component.
    """
    person_files = [person_name_to_file[person_name] for person_name in row_index_to_person_name]
    video_path = generate_video(scene_file_name, person_files)
    if not video_path:
        return gr.update(visible=False)
    return str(video_path)
def visualize_dynamic_generate(person_name_to_file, person_name_to_info, person_name_to_description):
    """Build the whole page: scene section, person section, video and reset wiring.

    Args:
        person_name_to_file: Passed as input to the video generation handler.
        person_name_to_info: Forwarded to ``visualize_person``.
        person_name_to_description: Forwarded to ``visualize_person`` and
            used as input of the reset handler.

    Returns:
        ``(reset_inputs, reset_outputs)`` -- the component lists wired to
        ``reset_person_page``.
    """
    gr.Markdown("## Scene Information")
    scene_dropdown = visualize_scene()
    # Person limit of the scene that is selected initially.
    max_person_number = get_max_person_number(SCENE_ROOT_DIR / scene_dropdown.value)
    gr.Markdown("## Person Information")
    # From here on max_person_number is the gr.State returned by visualize_person.
    row_index_to_person_name, max_person_number, _, rows, _, add_button, add_person_dropdown, person_dataframe, generate_button \
        = visualize_person(person_name_to_info, person_name_to_description, max_person_number)
    video = gr.Video(visible=False)
    clear_button = gr.Button("Clear")
    # NOTE(review): the two click handlers below both write to `video`;
    # confirm the resulting visibility when generate_video returns a falsy path.
    generate_button.click(
        fn=generate_button_click_change_video,
        inputs=[scene_dropdown, row_index_to_person_name, person_name_to_file],
        outputs=video
    )
    generate_button.click(
        fn=generate_button_click_change_state,
        outputs=[generate_button, video]
    )
    reset_inputs = [scene_dropdown, person_name_to_description]
    # The order must match the list returned by reset_person_page.
    reset_outputs = rows + [row_index_to_person_name, add_button, add_person_dropdown, person_dataframe, max_person_number, generate_button, video]
    # Changing the scene and pressing "Clear" both reset the person section.
    scene_dropdown.change(
        fn=reset_person_page,
        inputs=reset_inputs,
        outputs=reset_outputs
    )
    clear_button.click(
        fn=reset_person_page,
        inputs=reset_inputs,
        outputs=reset_outputs
    )
    return reset_inputs, reset_outputs
# Script entry point: build the page inside a Blocks app and start the server.
if __name__ == '__main__':
    with gr.Blocks() as demo:
        # NOTE(review): visualize_dynamic_generate declares three required
        # parameters (person_name_to_file, person_name_to_info,
        # person_name_to_description) but is called without arguments here.
        visualize_dynamic_generate()
    demo.launch()