Spaces:
Sleeping
Sleeping
| from typing import List | |
| from tqdm import tqdm | |
| import yaml | |
| from replay_buffer import ReplayBuffer, RolloutReplayBuffer | |
class Pretraining:
    """Fill a replay buffer from recorded YAML samples and pre-train a model.

    Args:
        file_names: Paths of the YAML files that hold the recorded samples.
        model: Object exposing ``prepare_state(latest_scan, distance, cos,
            sin, collision, goal, action)`` returning ``(state, terminal)``
            and ``train(replay_buffer=..., iterations=..., batch_size=...)``.
        replay_buffer: Object exposing
            ``add(state, action, reward, terminal, next_state)``.
        reward_function: Callable invoked as
            ``reward_function(goal, collision, action, latest_scan)``.
    """

    def __init__(
        self,
        file_names: List[str],
        model: object,
        replay_buffer: object,
        reward_function,
    ):
        self.file_names = file_names
        self.model = model
        self.replay_buffer = replay_buffer
        self.reward_function = reward_function

    def _prepare_sample(self, sample):
        """Pass one recorded sample's fields to ``model.prepare_state``."""
        return self.model.prepare_state(
            sample["latest_scan"],
            sample["distance"],
            sample["cos"],
            sample["sin"],
            sample["collision"],
            sample["goal"],
            sample["action"],
        )

    def load_buffer(self):
        """Convert consecutive recorded samples into replay-buffer transitions.

        For every file, each sample ``i`` is paired with sample ``i + 1``;
        the pair is stored as ``(state, action, reward, next_terminal,
        next_state)``. Pairs whose current sample is reported as terminal by
        ``model.prepare_state`` are skipped.

        Returns:
            The replay buffer passed to the constructor, after filling.
        """
        for file_name in self.file_names:
            print("Loading file: ", file_name)
            with open(file_name, "r") as file:
                # NOTE(review): full_load resolves more than plain data types;
                # if these files can come from an untrusted source, prefer
                # yaml.safe_load.
                # An empty YAML document loads as None, so fall back to an
                # empty list instead of failing on len(None) below.
                samples = yaml.full_load(file) or []

            # Index 0 is never used as the current sample, and the last
            # sample only ever serves as the "next" half of a pair.
            for i in tqdm(range(1, len(samples) - 1)):
                sample = samples[i]
                state, terminal = self._prepare_sample(sample)
                if terminal:
                    continue

                next_sample = samples[i + 1]
                next_state, next_terminal = self._prepare_sample(next_sample)

                # The reward combines the action recorded in the current
                # sample with the goal/collision/scan of the next sample.
                action = sample["action"]
                reward = self.reward_function(
                    next_sample["goal"],
                    next_sample["collision"],
                    action,
                    next_sample["latest_scan"],
                )
                self.replay_buffer.add(
                    state, action, reward, next_terminal, next_state
                )
        return self.replay_buffer

    def train(
        self,
        pretraining_iterations,
        replay_buffer,
        iterations,
        batch_size,
    ):
        """Call ``model.train`` ``pretraining_iterations`` times.

        Args:
            pretraining_iterations: Number of ``model.train`` calls to make.
            replay_buffer: Buffer forwarded to every ``model.train`` call.
            iterations: Forwarded to ``model.train`` as ``iterations``.
            batch_size: Forwarded to ``model.train`` as ``batch_size``.
        """
        print("Running Pretraining")
        for _ in tqdm(range(pretraining_iterations)):
            self.model.train(
                replay_buffer=replay_buffer,
                iterations=iterations,
                batch_size=batch_size,
            )
        print("Model Pretrained")
def get_buffer(
    model,
    sim,
    load_saved_buffer,
    pretrain,
    pretraining_iterations,
    training_iterations,
    batch_size,
    buffer_size=50000,
    random_seed=666,
    file_names=None,
    history_len=10,
):
    """Create a replay buffer, optionally pre-filled and used for pre-training.

    Args:
        model: Model handed to ``Pretraining``.
        sim: Object whose ``get_reward`` attribute is used as the reward
            function when loading recorded samples.
        load_saved_buffer: When true, fill the buffer from ``file_names``.
        pretrain: When true, pre-train ``model`` on the loaded buffer.
            Requires ``load_saved_buffer`` to be true.
        pretraining_iterations: Forwarded to ``Pretraining.train``.
        training_iterations: Forwarded to ``Pretraining.train`` as
            ``iterations``.
        batch_size: Forwarded to ``Pretraining.train``.
        buffer_size: Forwarded to ``ReplayBuffer``.
        random_seed: Forwarded to ``ReplayBuffer``.
        file_names: Paths of the sample files to load. Defaults to
            ``["assets/data.yml"]`` when omitted or ``None``.
        history_len: Not used by this function.

    Returns:
        The replay buffer, filled from the files when ``load_saved_buffer``
        is true.

    Raises:
        AssertionError: If ``pretrain`` is true while ``load_saved_buffer``
            is false.
    """
    # Build the default per call so the list handed to Pretraining is never
    # shared between invocations.
    if file_names is None:
        file_names = ["assets/data.yml"]

    replay_buffer = ReplayBuffer(buffer_size=buffer_size, random_seed=random_seed)

    # Raised explicitly (rather than via `assert`) so the check is still
    # enforced when Python runs with -O; the exception type is unchanged.
    if pretrain and not load_saved_buffer:
        raise AssertionError(
            "To pre-train model, load_saved_buffer must be set to True"
        )

    if not load_saved_buffer:
        return replay_buffer

    pretraining = Pretraining(
        file_names=file_names,
        model=model,
        replay_buffer=replay_buffer,
        reward_function=sim.get_reward,
    )  # instantiate pre-training
    # fill buffer with experiences from the sample files
    replay_buffer = pretraining.load_buffer()

    if pretrain:
        pretraining.train(
            pretraining_iterations=pretraining_iterations,
            replay_buffer=replay_buffer,
            iterations=training_iterations,
            batch_size=batch_size,
        )  # run pre-training

    return replay_buffer