| | import logging |
| | import threading |
| | import time |
| |
|
| | from openpi_client.runtime import agent as _agent |
| | from openpi_client.runtime import environment as _environment |
| | from openpi_client.runtime import subscriber as _subscriber |
| |
|
| |
|
class Runtime:
    """The core module orchestrating interactions between key components of the system.

    Each episode repeatedly asks the ``environment`` for an observation, asks the
    ``agent`` for an action, applies that action, and notifies every subscriber.
    """

    def __init__(
        self,
        environment: _environment.Environment,
        agent: _agent.Agent,
        subscribers: list[_subscriber.Subscriber],
        max_hz: float = 0,
        num_episodes: int = 1,
        max_episode_steps: int = 0,
    ) -> None:
        """Create a runtime.

        Args:
            environment: Provides observations, applies actions, and reports
                episode completion.
            agent: Maps each observation to an action.
            subscribers: Notified at episode start, after every step, and at
                episode end.
            max_hz: Upper bound on the step rate. Values <= 0 disable throttling.
            num_episodes: Number of episodes executed by ``run()``.
            max_episode_steps: Maximum number of steps per episode. Values <= 0
                mean no step limit.
        """
        self._environment = environment
        self._agent = agent
        self._subscribers = subscribers
        self._max_hz = max_hz
        self._num_episodes = num_episodes
        self._max_episode_steps = max_episode_steps

        self._in_episode = False
        self._episode_steps = 0

    def run(self) -> None:
        """Run ``num_episodes`` episodes back to back, then reset the environment.

        An episode ends when the environment reports completion, when
        ``max_episode_steps`` steps have been taken (if a limit is set), or when
        ``mark_episode_complete()`` is called.
        """
        for _ in range(self._num_episodes):
            self._run_episode()

        # Reset once more after the final episode so the environment is not left
        # in its last-episode state (presumably matters most for real hardware).
        self._environment.reset()

    def run_in_new_thread(self) -> threading.Thread:
        """Start ``run()`` in a new thread and return the started thread."""
        thread = threading.Thread(target=self.run)
        thread.start()
        return thread

    def mark_episode_complete(self) -> None:
        """Mark the current episode as finished.

        The episode loop checks this flag between steps, so the step in progress
        (if any) completes before the episode ends.
        """
        self._in_episode = False

    def _run_episode(self) -> None:
        """Run a single episode until it is marked complete."""
        logging.info("Starting episode...")
        self._environment.reset()
        self._agent.reset()
        for subscriber in self._subscribers:
            subscriber.on_episode_start()

        self._in_episode = True
        self._episode_steps = 0
        step_time = 1 / self._max_hz if self._max_hz > 0 else 0
        # Use a monotonic clock for interval measurement: wall-clock time can jump
        # (e.g. NTP adjustments), which would corrupt dt and the sleep below.
        last_step_time = time.monotonic()

        while self._in_episode:
            self._step()

            # Throttle the loop so steps are issued no faster than max_hz.
            now = time.monotonic()
            dt = now - last_step_time
            if dt < step_time:
                time.sleep(step_time - dt)
                last_step_time = time.monotonic()
            else:
                last_step_time = now

        logging.info("Episode completed.")
        for subscriber in self._subscribers:
            subscriber.on_episode_end()

    def _step(self) -> None:
        """Run one observe -> act -> apply cycle, notify subscribers, and check for episode end."""
        observation = self._environment.get_observation()
        action = self._agent.get_action(observation)
        self._environment.apply_action(action)
        # Count the step before the limit check below, so an episode with
        # max_episode_steps=N takes exactly N steps (not N + 1).
        self._episode_steps += 1

        for subscriber in self._subscribers:
            subscriber.on_step(observation, action)

        if self._environment.is_episode_complete() or (
            self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps
        ):
            self.mark_episode_complete()
|