import argparse
import copy
import os
import pickle as pkl
from collections import defaultdict
from datetime import timedelta
from typing import Any

from gen_schedule.persona import Person
from gen_schedule.data import event_constant
from gen_schedule.gen_utils import even_split
from gen_schedule.core_methods import (
    run_gpt_prompt_core_v1,
    gen_core_simple_v1,
    event_core,
    schedule_core,
)

def gen_event_base(person: Person, activities: set, scene_prior, batch_size=10, model_name: str = "gpt4", openai_api_key: str = ""):
    """Build the person's event base: for each activity, candidate rooms, objects,
    and receptacles together with their probabilities."""
    receptacle_info = scene_prior['receptacle_info']
    object_list = scene_prior['object_info']
    todo_activities = person.filter_event(sorted(activities))
    if len(todo_activities) == 0:
        return {}

    # Stage 1: query candidate rooms for each activity, in batches.
    batched_act_lists = even_split(todo_activities, batch_size)
    all_act_locs = {}
    for act_list in batched_act_lists:
        act_location_event = run_gpt_prompt_core_v1.gen_activity_location(person, act_list, receptacle_info, model_name, openai_api_key)
        all_act_locs.update(act_location_event)

    # Per-activity room probabilities.
    act_room_prob = defaultdict(dict)
    for act, room_probs in all_act_locs.items():
        for room, prob in room_probs:
            act_room_prob[act][room] = prob

    # Stage 2: query candidate objects for each (activity, room) pair, in batches.
    all_act_loc_pairs = gen_core_simple_v1.events_to_event_loc_pair(all_act_locs)
    batched_act_loc_pairs = even_split(sorted(all_act_loc_pairs), batch_size)
    all_act_loc_object = []
    for act_loc_pairs in batched_act_loc_pairs:
        act_location_object_event = run_gpt_prompt_core_v1.gen_activity_location_object_v2(person, act_loc_pairs, object_list, model_name, openai_api_key)
        all_act_loc_object += act_location_object_event

    # Per-event object probabilities.
    act_room_object_prob = defaultdict(dict)
    for event_case in all_act_loc_object:
        act, object_probs = event_case['action'], event_case['objects']
        for obj_name, _, prob in object_probs:
            act_room_object_prob[act][obj_name] = prob

    # Stage 3: query candidate receptacles for each (activity, room, object) triple,
    # using smaller batches since these prompts are longer.
    batched_act_loc_object = even_split(all_act_loc_object, int(batch_size * 0.6))
    all_final_events = {}
    for act_loc_object_pairs in batched_act_loc_object:
        activity_str = event_core.formatting_event_str_for_ask_receptacle_v1(act_loc_object_pairs)
        act_location_object_receptacle_event = run_gpt_prompt_core_v1.gen_activity_location_object_receptacle_v2(person, activity_str, receptacle_info, model_name, openai_api_key)
        all_final_events.update(act_location_object_receptacle_event)

    # Assemble the nested event base: activity -> room -> {room_prob, object_effect}.
    event_base = defaultdict(dict)
    for event, object_probs in all_final_events.items():
        # object_probs = {object: receptacle_probs}
        activity, location = event.split(' @ ')
        object_effect = {}
        for obj_name, receptacle_probs in object_probs.items():
            object_effect[obj_name] = {
                'object_prob': act_room_object_prob[event][obj_name],
                'receptacles': receptacle_probs,
            }
        event_base[activity][location] = {
            'room_prob': act_room_prob[activity].get(location, 0),
            'object_effect': object_effect,
        }

    person.update_event(event_base)
    return event_base

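# For reference, the event base assembled above is a nested mapping of roughly this
# shape. The activity/room/object names below are made-up examples, and the exact
# receptacle_probs format is whatever gen_activity_location_object_receptacle_v2
# returns:
#
#   {
#       "cook dinner": {
#           "kitchen": {
#               "room_prob": 0.8,
#               "object_effect": {
#                   "pan": {"object_prob": 0.9, "receptacles": <receptacle_probs>},
#               },
#           },
#       },
#   }
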
def gen_schedule_v1(person: Person, date_span, schedule_key='default', model_name: str = "gpt4", openai_api_key: str = "") -> Person:
    """Generate a day-by-day schedule for every date in the inclusive date_span."""
    st, ed = date_span
    date_list = []
    curr_date = st
    while curr_date <= ed:
        date_list.append(curr_date)
        curr_date += timedelta(days=1)

    for date in date_list:
        curr_activity_list = person.primary_activity_set.copy()

        # Broad (high-level) plan for the day.
        broad_schedule = run_gpt_prompt_core_v1.gen_broad_schedule(person, date, model_name=model_name, openai_api_key=openai_api_key)
        person.update_general_plan(broad_schedule, date, schedule_key=schedule_key)

        # Decompose the broad plan into fine-grained activities.
        merged_broad_schedule = schedule_core.truncate_schedule(broad_schedule)
        broad_schedule_str = schedule_core.schedule_to_str(merged_broad_schedule)
        decomposed_schedule = run_gpt_prompt_core_v1.gen_decomposed_schedule(person, date, broad_schedule_str, curr_activity_list, model_name=model_name, openai_api_key=openai_api_key)

        # Merge newly generated activity names with known ones so synonyms share an alias.
        gen_activity_list = list(set(a['activity'] for a in decomposed_schedule))
        activity_synonym_pair = run_gpt_prompt_core_v1.merge_activity_synonyms(curr_activity_list, gen_activity_list, model_name=model_name, openai_api_key=openai_api_key)
        _ = person.update_alias(activity_synonym_pair)

        person.update_schedule(decomposed_schedule, date, schedule_key=schedule_key)

    return person

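# Illustrative call (not taken from this module's entry points): date_span is an
# inclusive (start, end) pair of datetime.date objects, iterated one day at a time.
#
#   from datetime import date
#   person = gen_schedule_v1(person, (date(2024, 1, 1), date(2024, 1, 7)), schedule_key="default")
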
def gen_character(persona, date_span, model_name, openai_api_key) -> dict[str, Any]:
    """Instantiate a Person from the persona, generate its schedules, and ground
    its activities in the scene's rooms, objects, and receptacles."""
    person = Person(persona)
    seed_activities = event_constant.CustomActivitiesV2
    receptacle_info = event_constant.room_to_receptacle_str.split('\n')
    object_info = event_constant.appeared_objects
    scene_prior = {
        'receptacle_info': receptacle_info,
        'object_info': object_info,
    }

    person.primary_activity_set.update(seed_activities)
    person = gen_schedule_v1(person, date_span=date_span, model_name=model_name, openai_api_key=openai_api_key)

    # Build the person's activity/event base grounded in the scene prior.
    gen_event_base(person, person.primary_activity_set, scene_prior, model_name=model_name, openai_api_key=openai_api_key)

    character_dict = person.get_character_dict()
    return character_dict

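if __name__ == "__main__":
    # Minimal CLI sketch, for illustration only: the flag names and the pickle-based
    # persona/output format below are assumptions, not confirmed by this module.
    from datetime import date

    parser = argparse.ArgumentParser(
        description="Generate a character (schedule + event base) from a persona."
    )
    parser.add_argument("--persona_path", required=True, help="Pickle file holding the persona.")
    parser.add_argument("--output_path", required=True, help="Where to pickle the character dict.")
    parser.add_argument("--model_name", default="gpt4")
    parser.add_argument("--openai_api_key", default=os.environ.get("OPENAI_API_KEY", ""))
    args = parser.parse_args()

    with open(args.persona_path, "rb") as f:
        persona = pkl.load(f)

    # An arbitrary one-week span; adjust to the desired simulation window.
    character_dict = gen_character(
        persona,
        date_span=(date(2024, 1, 1), date(2024, 1, 7)),
        model_name=args.model_name,
        openai_api_key=args.openai_api_key,
    )

    with open(args.output_path, "wb") as f:
        pkl.dump(character_dict, f)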