In [None]:
from pathlib import Path 
import os
import glob
import json
import sys
sys.path.append(str(Path(os.path.abspath('')).parent))

import torch
import torch.distributions as D
import numpy as np
import torch.nn.functional as F

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.animation as animation

import wandb
from tqdm import tqdm
# Handle to the Weights & Biases public API.
api = wandb.Api()

# The checkpoint is looked up in the `models` folder that sits next to the
# current working directory (os.path.abspath('') resolves to the cwd).
agent_path = Path(os.path.abspath('')).parent.joinpath('models', 'genrl_stickman_500k_2.pt')
print("Model path", agent_path)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
agent = torch.load(agent_path)

In [None]:
from tools.genrl_utils import ViCLIPGlobalInstance, DOMAIN2PREDICATES
# Name of the ViCLIP variant the agent was configured with ('viclip' when the
# config does not define `viclip_model`).
model_name = getattr(agent.cfg, 'viclip_model', 'viclip')

# Get ViCLIP
# (Re)create the wrapper only when it is missing from the kernel or was built
# for a different model, so re-running this cell does not redo the setup.
if 'viclip_global_instance' not in locals() or model_name != viclip_global_instance._model:
    viclip_global_instance = ViCLIPGlobalInstance(model_name)

# Checked independently of the creation step above: a wrapper that exists but
# was never instantiated (e.g. an earlier run was interrupted) is retried.
if not viclip_global_instance._instantiated:
    print("Instantiating")
    viclip_global_instance.instantiate()

# Always (re)bind the handles so later cells do not depend on names left over
# from a previous execution of this cell.
clip = viclip_global_instance.viclip
tokenizer = viclip_global_instance.viclip_tokenizer

In [None]:
import cv2

def get_vid_feat(frames, clip):
    """Return the video features that `clip` computes for `frames`.

    Thin wrapper that forwards `frames` to `clip.get_vid_features` and hands
    back whatever that call returns.
    """
    features = clip.get_vid_features(frames)
    return features

def _frame_from_video(video):
 while video.isOpened():
 success, frame = video.read()
 if success:
 yield frame
 else:
 break

# Per-channel mean / std used for pixel normalization, shaped (1, 1, 3) so they
# broadcast over the two leading axes of a channels-last image.
v_mean = np.array([[[0.485, 0.456, 0.406]]])
v_std = np.array([[[0.229, 0.224, 0.225]]])


def normalize(data):
    """Scale `data` from [0, 255] to [0, 1], then standardize it per channel
    with `v_mean` / `v_std`."""
    scaled = data / 255.0
    centered = scaled - v_mean
    return centered / v_std


def denormalize(data):
    """Invert `normalize`: undo the per-channel standardization and rescale
    the result by 255."""
    unscaled = data * v_std
    shifted = unscaled + v_mean
    return shifted * 255

def frames2tensor(vid_list, fnum=8, target_size=(224, 224), device=torch.device('cuda')):
    """Convert the first clip in `vid_list` into a normalized float tensor.

    Only `vid_list[0]` is used. Every frame is resized with
    `cv2.resize(frame, target_size)`, normalized with `normalize`, and the
    frames are joined along a new time axis. The channels-last array is then
    rearranged with the axis order (0, 1, 4, 2, 3) and moved to `device`.

    `fnum` is only a lower bound on the number of frames (enforced with an
    assertion); it does not subsample or truncate the clip.
    """
    clip_frames = list(vid_list[0])
    assert len(clip_frames) >= fnum

    tubes = []
    for frame in clip_frames:
        resized = cv2.resize(frame, target_size)
        # Two leading singleton axes: batch and time.
        tubes.append(normalize(resized)[np.newaxis, np.newaxis])

    stacked = np.concatenate(tubes, axis=1)
    # Move the channel axis (last) in front of the two spatial axes.
    channels_first = np.transpose(stacked, (0, 1, 4, 2, 3))
    tensor = torch.from_numpy(channels_first)
    return tensor.to(device, non_blocking=True).float()


def get_video_feat(frames, device=torch.device('cuda'), flip=False):
    """Compute video features for the first 8 (sub-sampled) frames.

    `frames` is a 5-axis numpy array whose axis 1 is time; the caller builds
    it channels-last, i.e. (batch, time, H, W, C). A single frame is repeated
    8 times and a 4-frame clip has every frame doubled, so both reach 8
    frames. Longer inputs are strided by `length // 128` (at least 1) before
    the first 8 frames are taken.

    When `flip` is true, the second-to-last axis is reversed (a horizontal
    flip for channels-last frames).

    Relies on the notebook-level `clip` model. Returns `(vid_feat,
    chosen_frames)` where `chosen_frames` is the tensor produced by
    `frames2tensor`.
    """
    n_input = frames.shape[1]
    if n_input == 1:
        # Image
        frames = frames.repeat(8, axis=1)
    elif n_input == 4:
        # Short video
        frames = frames.repeat(2, axis=1)

    stride = frames.shape[1] // 128 or 1
    frames = frames[:, ::stride]

    # Horizontally flip
    if flip:
        frames = frames[..., ::-1, :]

    print(frames.shape,)
    first_eight = frames[:, :8]
    chosen_frames = frames2tensor(first_eight, device=device)
    vid_feat = get_vid_feat(chosen_frames, clip)
    return vid_feat, chosen_frames

# Folder with the sample videos, next to the current working directory.
VIDEO_PATH = Path(os.path.abspath('')).parent / 'assets' / 'video_samples'
video_name = 'headstand.mp4'

video_file_path = str(VIDEO_PATH / video_name)
print(video_file_path)
video = cv2.VideoCapture(video_file_path)
# Fail early with a clear message instead of an obscure np.stack error when
# the file is missing or cannot be decoded.
if not video.isOpened():
    raise FileNotFoundError(f'Could not open video file: {video_file_path}')
try:
    # OpenCV decodes to BGR; convert every frame to RGB.
    rgb_frames = [cv2.cvtColor(x, cv2.COLOR_BGR2RGB) for x in _frame_from_video(video)]
finally:
    # Always free the decoder / file handle, even if decoding fails midway.
    video.release()
if not rgb_frames:
    raise ValueError(f'No frames could be read from: {video_file_path}')

# Stack along time and add a leading batch axis of size 1.
frames = np.expand_dims(np.stack(rgb_frames, axis=0), axis=0)
print('Video length:', frames.shape[1])
with torch.no_grad():
    vid_feat, frames_feat = get_video_feat(frames, flip=False)
print(vid_feat.shape)
# Preview the first frame of the loaded video.
plt.imshow(frames[0, 0])

In [None]:
# Video features computed above drive the imagined rollout.
video_embed = vid_feat
T = video_embed.shape[0]

# Forwarded as the `denoise` argument of `video_imagine` below.
DENOISE = True

from torchvision.transforms import transforms as vision_trans
trasnf = vision_trans.Resize((64, 64), interpolation=vision_trans.InterpolationMode.NEAREST)

# Shorthand handles into the loaded agent.
world_model = agent.wm
wm = world_model
connector = agent.wm.connector
decoder = world_model.heads['decoder']
n_frames = connector.n_frames


with torch.no_grad():
    # Get actions
    # Repeat each of the T embeddings n_frames times and flatten them into a
    # single sequence of length n_frames * T with a leading batch axis of 1.
    video_embed = video_embed.unsqueeze(1).repeat(1, n_frames, 1).reshape(1, n_frames * T, -1)
    action = wm.connector.get_action(video_embed)

    # Imagine
    prior = wm.connector.video_imagine(video_embed, None, sample=False, reset_every_n_frames=False, denoise=DENOISE)
    # NOTE(review): the +0.5 presumably maps decoder outputs from
    # [-0.5, 0.5] to [0, 1] -- confirm against the decoder's training targets.
    prior_recon = decoder(wm.decoder_input_fn(prior))['observation'].mean + 0.5

    # Plotting video
    ims = []
    # Create the figure at its final size (4x2 inches) directly.
    fig, axes = plt.subplots(1, 1, figsize=(4, 2), frameon=False)
    fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    # Single shared axis: configure it once instead of on every frame.
    ax = axes
    ax.set_axis_off()

    for t in range(prior_recon.shape[1]):
        toadd = []
        for b in range(prior_recon.shape[0]):
            # Clamp with torch itself rather than relying on np.clip
            # dispatching to a tensor method.
            recon = prior_recon[b, t].cpu().permute(1, 2, 0).clamp(0, 1).numpy()
            img = cv2.resize((recon * 255).astype(np.uint8), (224, 224))
            orig_img = denormalize(frames_feat[b, t].cpu().permute(1, 2, 0)).numpy().astype(np.uint8)
            # Input frame on the left, imagined reconstruction on the right.
            frame = ax.imshow(np.concatenate([orig_img, img], axis=1))
            toadd.append(frame)  # artists drawn for this animation step
        ims.append(toadd)

    anim = animation.ArtistAnimation(fig, ims, interval=700, blit=True, repeat_delay=700, )

    # Save GIFs
    writer = animation.PillowWriter(fps=15, metadata=dict(artist='Me'), bitrate=1800,)
    domain = agent.cfg.task.split('_')[0]
    os.makedirs(f'videos/{domain}/video2video', exist_ok=True)
    # Path.stem drops the extension whatever its length (the previous [:-4]
    # slice assumed a three-character extension).
    file_path = f'videos/{domain}/video2video/{Path(video_name).stem.replace(" ","_")}.gif'
    anim.save(file_path, writer=writer, )
