File size: 2,422 Bytes
246c106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import imageio
from sim.main import InteractiveDigitalWorld
from sim.simulator import GenieSimulator
from sim.policy import RandomPlanarQuadDirectionalPolicy

if __name__ == '__main__':
    # Arrow overlay glyph. The untransformed image is used for the `w` direction
    # below, so it presumably points "up" as stored. Loaded once here rather than
    # re-read from disk for every post-processed frame.
    arrow_template = imageio.imread('sim/assets/arrow.jpg')

    def draw_action_arrow_to_image(image: np.ndarray, action: np.ndarray) -> np.ndarray:
        """Paste an arrow showing the action's direction into the image's top-left corner.

        Passed to the simulator as its ``post_processor``.

        Args:
            image: Frame to annotate; it is modified in place and also returned.
                NOTE(review): assumed to be an (H, W, C) array whose channel count
                and dtype match ``arrow.jpg`` and that is at least as large as the
                arrow — confirm against the simulator's output frames.
            action: Action array with a leading ``stride`` axis; only ``action[0]``
                is visualised. Its first two components must not both be non-zero.

        Returns:
            ``image`` with the arrow drawn over its top-left corner.

        Raises:
            ValueError: If both action components are non-zero (not axis-aligned).
        """
        step = action[0]  # drop the `stride` dimension; only the first step is drawn
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if step[0] * step[1] != 0:
            raise ValueError(
                f"expected an axis-aligned action, got components ({step[0]}, {step[1]})"
            )
        if step[0] > 0:     # `s`: mirror the stored arrow vertically
            arrow = np.flipud(arrow_template)
        elif step[1] < 0:   # `a`: rotate 90 degrees counter-clockwise
            arrow = np.rot90(arrow_template)
        elif step[1] > 0:   # `d`: rotate 90 degrees clockwise
            arrow = np.rot90(arrow_template, -1)
        else:               # `w` — NOTE(review): an all-zero action also lands here
            arrow = arrow_template
        image[:arrow.shape[0], :arrow.shape[1]] = arrow
        return image
    # Alternative configuration (tokenized MAGVIT encoder + ST-MaskGIT backbone),
    # kept for reference:
    #   image_encoder_type="magvit", image_encoder_ckpt="data/magvit2.ckpt",
    #   quantize=True, backbone_type="stmaskgit",
    #   backbone_ckpt="data/genie_lang/step_5", prompt_horizon=8,
    genie_simulator = GenieSimulator(
        # Active configuration: temporal VAE encoder (Stable Video Diffusion
        # checkpoint), unquantized latents, ST-MAR backbone.
        image_encoder_type='temporalvae',
        image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
        quantize=False,
        backbone_type="stmar",
        backbone_ckpt="data/language_table_scratch_mar_dynamics_gpu_8_nodes_4_16g/step_40000",
        # backbone_ckpt="data/genie_lang/step_5",
        prompt_horizon=11,

        action_stride=1,
        domain='language_table',
        post_processor=draw_action_arrow_to_image
    )

    # Seed the simulator with a static prompt: one frame repeated
    # `prompt_horizon` times along a new leading axis, paired with all-zero
    # actions of shape (prompt_horizon, action_stride, 2).
    # NOTE(review): if langtable_prompt.png carries an alpha channel the prompt
    # frames will be RGBA — confirm that is what the image encoder expects.
    current_image = imageio.imread('sim/assets/langtable_prompt.png')
    image_prompt = np.tile(
        current_image, (genie_simulator.prompt_horizon, 1, 1, 1)
    ).astype(np.uint8)
    action_prompt = np.zeros(
        (genie_simulator.prompt_horizon, genie_simulator.action_stride, 2),
        dtype=np.float32,  # allocate directly as float32 instead of casting a float64 copy
    )
    genie_simulator.set_initial_state((image_prompt, action_prompt))

    random_policy = RandomPlanarQuadDirectionalPolicy(increment=0.05)  # increment as in IRASim
    playground = InteractiveDigitalWorld(
        simulator=genie_simulator,
        policy=random_policy,
        offscreen=True,
        window_size=(512, 512),
    )

    num_steps = 50
    try:
        for _ in range(num_steps):
            playground.step()
        playground.save_video(save_path='test.mp4', as_gif=False)
    finally:
        # Release the playground even if a rollout step or the video export fails.
        playground.close()