worldmem / app.py
xizaoqu
update
de66f2b
raw
history blame contribute delete
23.5 kB
import gradio as gr
import time
import sys
import subprocess
import time
from pathlib import Path
import hydra
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict
import numpy as np
import torch
import torchvision.transforms as transforms
import cv2
import subprocess
from PIL import Image
from datetime import datetime
import spaces
from algorithms.worldmem import WorldMemMinecraft
from huggingface_hub import hf_hub_download
import tempfile
torch.set_float32_matmul_precision("high")
ACTION_KEYS = [
"inventory",
"ESC",
"hotbar.1",
"hotbar.2",
"hotbar.3",
"hotbar.4",
"hotbar.5",
"hotbar.6",
"hotbar.7",
"hotbar.8",
"hotbar.9",
"forward",
"back",
"left",
"right",
"cameraY",
"cameraX",
"jump",
"sneak",
"sprint",
"swapHands",
"attack",
"use",
"pickItem",
"drop",
]
# Mapping of input keys to action names
KEY_TO_ACTION = {
"Q": ("forward", 1),
"E": ("back", 1),
"W": ("cameraY", -1),
"S": ("cameraY", 1),
"A": ("cameraX", -1),
"D": ("cameraX", 1),
"U": ("drop", 1),
"N": ("noop", 1),
"1": ("hotbar.1", 1),
}
example_images = [
["1", "assets/ice_plains.png", "turn rightgo backward→look up→turn left→look down→turn right→go forward→turn left", 20, 3, 8],
["2", "assets/place.png", "put item→go backward→put item→go backward→go around", 20, 3, 8],
["3", "assets/rain_sunflower_plains.png", "turn right→look up→turn right→look down→turn left→go backward→turn left", 20, 3, 8],
["4", "assets/desert.png", "turn 360 degree→turn right→go forward→turn left", 20, 3, 8],
]
def load_custom_checkpoint(algo, checkpoint_path):
hf_ckpt = str(checkpoint_path).split('/')
repo_id = '/'.join(hf_ckpt[:2])
file_name = '/'.join(hf_ckpt[2:])
model_path = hf_hub_download(repo_id=repo_id,
filename=file_name)
ckpt = torch.load(model_path, map_location=torch.device('cpu'))
algo.load_state_dict(ckpt['state_dict'], strict=False)
def parse_input_to_tensor(input_str):
"""
Convert an input string into a (sequence_length, 25) tensor, where each row is a one-hot representation
of the corresponding action key.
Args:
input_str (str): A string consisting of "WASD" characters (e.g., "WASDWS").
Returns:
torch.Tensor: A tensor of shape (sequence_length, 25), where each row is a one-hot encoded action.
"""
# Get the length of the input sequence
seq_len = len(input_str)
# Initialize a zero tensor of shape (seq_len, 25)
action_tensor = torch.zeros((seq_len, 25))
# Iterate through the input string and update the corresponding positions
for i, char in enumerate(input_str):
action, value = KEY_TO_ACTION.get(char.upper()) # Convert to uppercase to handle case insensitivity
if action and action in ACTION_KEYS:
index = ACTION_KEYS.index(action)
action_tensor[i, index] = value # Set the corresponding action index to 1
return action_tensor
def load_image_as_tensor(image_path: str) -> torch.Tensor:
"""
Load an image and convert it to a 0-1 normalized tensor.
Args:
image_path (str): Path to the image file.
Returns:
torch.Tensor: Image tensor of shape (C, H, W), normalized to [0,1].
"""
if isinstance(image_path, str):
image = Image.open(image_path).convert("RGB") # Ensure it's RGB
else:
image = image_path
transform = transforms.Compose([
transforms.ToTensor(), # Converts to tensor and normalizes to [0,1]
])
return transform(image)
def enable_amp(model, precision="16-mixed"):
original_forward = model.forward
def amp_forward(*args, **kwargs):
with torch.autocast("cuda", dtype=torch.float16 if precision == "16-mixed" else torch.bfloat16):
return original_forward(*args, **kwargs)
model.forward = amp_forward
return model
memory_frames = []
input_history = ""
ICE_PLAINS_IMAGE = "assets/ice_plains.png"
DESERT_IMAGE = "assets/desert.png"
SAVANNA_IMAGE = "assets/savanna.png"
PLAINS_IMAGE = "assets/plans.png"
PLACE_IMAGE = "assets/place.png"
SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
device = torch.device('cuda')
def save_video(frames, path="output.mp4", fps=10):
h, w, _ = frames[0].shape
out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'XVID'), fps, (w, h))
for frame in frames:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
ffmpeg_cmd = [
"ffmpeg", "-y", "-i", path, "-c:v", "libx264", "-crf", "23", "-preset", "medium", path
]
subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
return path
cfg = OmegaConf.load("configurations/huggingface.yaml")
worldmem = WorldMemMinecraft(cfg)
load_custom_checkpoint(algo=worldmem.diffusion_model, checkpoint_path=cfg.diffusion_path)
load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
worldmem.to("cuda").eval()
# worldmem = enable_amp(worldmem, precision="16-mixed")
actions = np.zeros((1, 25), dtype=np.float32)
poses = np.zeros((1, 5), dtype=np.float32)
def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, self_frames, self_actions,
self_poses, self_memory_c2w, self_frame_idx):
return 5 * len(action) if self_actions is not None else 5
@spaces.GPU(duration=get_duration_single_image_to_long_video)
def run_interactive(first_frame, action, first_pose, device, self_frames, self_actions,
self_poses, self_memory_c2w, self_frame_idx):
new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
action,
first_pose,
device=device,
self_frames=self_frames,
self_actions=self_actions,
self_poses=self_poses,
self_memory_c2w=self_memory_c2w,
self_frame_idx=self_frame_idx)
return new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
def set_denoising_steps(denoising_steps, sampling_timesteps_state):
worldmem.sampling_timesteps = denoising_steps
worldmem.diffusion_model.sampling_timesteps = denoising_steps
sampling_timesteps_state = denoising_steps
print("set denoising steps to", worldmem.sampling_timesteps)
return sampling_timesteps_state
def set_context_length(context_length, sampling_context_length_state):
worldmem.n_tokens = context_length
sampling_context_length_state = context_length
print("set context length to", worldmem.n_tokens)
return sampling_context_length_state
def set_memory_length(memory_length, sampling_memory_length_state):
worldmem.condition_similar_length = memory_length
sampling_memory_length_state = memory_length
print("set memory length to", worldmem.condition_similar_length)
return sampling_memory_length_state
def generate(keys, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
input_actions = parse_input_to_tensor(keys)
if self_frames is None:
new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
actions[0],
poses[0],
device=device,
self_frames=self_frames,
self_actions=self_actions,
self_poses=self_poses,
self_memory_c2w=self_memory_c2w,
self_frame_idx=self_frame_idx)
new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
input_actions,
None,
device=device,
self_frames=self_frames,
self_actions=self_actions,
self_poses=self_poses,
self_memory_c2w=self_memory_c2w,
self_frame_idx=self_frame_idx)
memory_frames = np.concatenate([memory_frames, new_frame[:,0]])
out_video = memory_frames.transpose(0,2,3,1).copy()
out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
out_video = (out_video * 255).astype(np.uint8)
last_frame = out_video[-1].copy()
border_thickness = 2
out_video[-len(new_frame):, :border_thickness, :, :] = [255, 0, 0]
out_video[-len(new_frame):, -border_thickness:, :, :] = [255, 0, 0]
out_video[-len(new_frame):, :, :border_thickness, :] = [255, 0, 0]
out_video[-len(new_frame):, :, -border_thickness:, :] = [255, 0, 0]
temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
save_video(out_video, temporal_video_path)
input_history += keys
# now = datetime.now()
# folder_name = now.strftime("%Y-%m-%d_%H-%M-%S")
# folder_path = os.path.join("/mnt/xiaozeqi/worldmem/output_material", folder_name)
# os.makedirs(folder_path, exist_ok=True)
# data_dict = {
# "input_history": input_history,
# "memory_frames": memory_frames,
# "self_frames": self_frames,
# "self_actions": self_actions,
# "self_poses": self_poses,
# "self_memory_c2w": self_memory_c2w,
# "self_frame_idx": self_frame_idx,
# }
# np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
return last_frame, temporal_video_path, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
def reset(selected_image):
self_frames = None
self_poses = None
self_actions = None
self_memory_c2w = None
self_frame_idx = None
memory_frames = load_image_as_tensor(selected_image).numpy()[None]
input_history = ""
new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
actions[0],
poses[0],
device=device,
self_frames=self_frames,
self_actions=self_actions,
self_poses=self_poses,
self_memory_c2w=self_memory_c2w,
self_frame_idx=self_frame_idx)
return input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
def on_image_click(selected_image):
input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = reset(selected_image)
return input_history, selected_image, selected_image, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
def set_memory(examples_case, image_display, log_output, slider_denoising_step, slider_context_length, slider_memory_length):
if examples_case == '1':
data_bundle = np.load("assets/examples/case1.npz")
input_history = data_bundle['input_history'].item()
memory_frames = data_bundle['memory_frames']
self_frames = data_bundle['self_frames']
self_actions = data_bundle['self_actions']
self_poses = data_bundle['self_poses']
self_memory_c2w = data_bundle['self_memory_c2w']
self_frame_idx = data_bundle['self_frame_idx']
elif examples_case == '2':
data_bundle = np.load("assets/examples/case2.npz")
input_history = data_bundle['input_history'].item()
memory_frames = data_bundle['memory_frames']
self_frames = data_bundle['self_frames']
self_actions = data_bundle['self_actions']
self_poses = data_bundle['self_poses']
self_memory_c2w = data_bundle['self_memory_c2w']
self_frame_idx = data_bundle['self_frame_idx']
elif examples_case == '3':
data_bundle = np.load("assets/examples/case3.npz")
input_history = data_bundle['input_history'].item()
memory_frames = data_bundle['memory_frames']
self_frames = data_bundle['self_frames']
self_actions = data_bundle['self_actions']
self_poses = data_bundle['self_poses']
self_memory_c2w = data_bundle['self_memory_c2w']
self_frame_idx = data_bundle['self_frame_idx']
elif examples_case == '4':
data_bundle = np.load("assets/examples/case4.npz")
input_history = data_bundle['input_history'].item()
memory_frames = data_bundle['memory_frames']
self_frames = data_bundle['self_frames']
self_actions = data_bundle['self_actions']
self_poses = data_bundle['self_poses']
self_memory_c2w = data_bundle['self_memory_c2w']
self_frame_idx = data_bundle['self_frame_idx']
out_video = memory_frames.transpose(0,2,3,1)
out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
out_video = (out_video * 255).astype(np.uint8)
temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
save_video(out_video, temporal_video_path)
return input_history, out_video[-1], temporal_video_path, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
css = """
h1 {
text-align: center;
display:block;
}
"""
with gr.Blocks(css=css) as demo:
gr.Markdown(
"""
# WORLDMEM: Long-term Consistent World Simulation with Memory
<div style="text-align: center;">
<a style="display:inline-block; margin: 0 10px;" href="https://github.com/xizaoqu/WorldMem">
<img src="https://img.shields.io/badge/GitHub-Repository-black?logo=github"/>
</a>
<a style="display:inline-block; margin: 0 10px;" href="https://xizaoqu.github.io/worldmem/">
<img src="https://img.shields.io/badge/Project_Page-blue"/>
</a>
<a style="display:inline-block; margin: 0 10px;" href="https://arxiv.org/abs/2504.12369">
<img src="https://img.shields.io/badge/arXiv-Paper-red"/>
</a>
</div>
"""
)
gr.Markdown(
"""
## 🚀 How to Explore WorldMem
Follow these simple steps to get started:
1. **Choose a scene**.
2. **Input your action sequence**.
3. **Click "Generate"**.
- You can continuously click **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
- For best performance, we recommend **running locally** (1s/frame on H100) instead of Spaces (5s/frame).
- ⭐️ If you like this project, please [give it a star on GitHub](https://github.com/xizaoqu/WorldMem)!
- 💬 For questions or feedback, feel free to open an issue or email me at **zeqixiao1@gmail.com**.
Happy exploring! 🌍
"""
)
example_actions = {"turn left→turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
"turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
"turn right→go backward→look up→turn left→look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
"turn right→go forward→turn right": "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
"turn right→look up→turn right→look down": "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSS",
"put item→go backward→put item→go backward":"SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEE"}
selected_image = gr.State(ICE_PLAINS_IMAGE)
with gr.Row(variant="panel"):
with gr.Column():
gr.Markdown("🖼️ Start from this frame.")
image_display = gr.Image(value=selected_image.value, interactive=False, label="Current Frame")
with gr.Column():
gr.Markdown("🎞️ Generated videos. New contents are marked in red box.")
video_display = gr.Video(autoplay=True, loop=True)
gr.Markdown("### 🏞️ Choose a scene and start generation.")
with gr.Row():
image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
with gr.Row(variant="panel"):
with gr.Column(scale=2):
gr.Markdown("### 🕹️ Input action sequences for interaction.")
input_box = gr.Textbox(label="Action Sequences", placeholder="Enter action sequences here, e.g. (AAAAAAAAAAAADDDDDDDDDDDD)", lines=1, max_lines=1)
log_output = gr.Textbox(label="History Sequences", interactive=False)
gr.Markdown(
"""
### 💡 Action Key Guide
<pre style="font-family: monospace; font-size: 14px; line-height: 1.6;">
W: Turn up S: Turn down A: Turn left D: Turn right
Q: Go forward E: Go backward N: No-op U: Use item
</pre>
"""
)
gr.Markdown("### 👇 Click to quickly set action sequence examples.")
with gr.Row():
buttons = []
for action_key in list(example_actions.keys())[:2]:
with gr.Column(scale=len(action_key)):
buttons.append(gr.Button(action_key))
with gr.Row():
for action_key in list(example_actions.keys())[2:4]:
with gr.Column(scale=len(action_key)):
buttons.append(gr.Button(action_key))
with gr.Row():
for action_key in list(example_actions.keys())[4:6]:
with gr.Column(scale=len(action_key)):
buttons.append(gr.Button(action_key))
with gr.Column(scale=1):
submit_button = gr.Button("🎬 Generate!", variant="primary")
reset_btn = gr.Button("🔄 Reset")
gr.Markdown("### ⚙️ Advanced Settings")
slider_denoising_step = gr.Slider(
minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
label="Denoising Steps",
info="Higher values yield better quality but slower speed"
)
slider_context_length = gr.Slider(
minimum=2, maximum=10, value=worldmem.n_tokens, step=1,
label="Context Length",
info="How many previous frames in temporal context window."
)
slider_memory_length = gr.Slider(
minimum=4, maximum=16, value=worldmem.condition_similar_length, step=1,
label="Memory Length",
info="How many previous frames in memory window."
)
sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
sampling_context_length_state = gr.State(worldmem.n_tokens)
sampling_memory_length_state = gr.State(worldmem.condition_similar_length)
memory_frames = gr.State(load_image_as_tensor(selected_image.value)[None].numpy())
self_frames = gr.State()
self_actions = gr.State()
self_poses = gr.State()
self_memory_c2w = gr.State()
self_frame_idx = gr.State()
def set_action(action):
return action
for button, action_key in zip(buttons, list(example_actions.keys())):
button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
gr.Markdown("### 👇 Click to review generated examples, and continue generation based on them.")
example_case = gr.Textbox(label="Case", visible=False)
image_output = gr.Image(visible=False)
examples = gr.Examples(
examples=example_images,
inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
cache_examples=False
)
example_case.change(
fn=set_memory,
inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx]
)
submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image,image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
slider_memory_length.change(fn=set_memory_length, inputs=[slider_memory_length, sampling_memory_length_state], outputs=sampling_memory_length_state)
demo.launch()