Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
from pathlib import Path | |
# Directory two levels above this module; both locations below hang off it.
_PROJECT_ROOT = Path(__file__).parent.parent

# Folder that holds the pretrained checkpoints loaded by this module.
MODELS_ROOT_PATH = _PROJECT_ROOT / 'models'

# Location of the InternVideo sources that are put on ``sys.path`` at load time.
INTERNVIDEO_PATH = _PROJECT_ROOT / 'third_party' / 'InternVideo'
# Candidate text labels for each domain. The domain is the part of the task
# name before the first underscore (e.g. "walker_run" -> "walker").
DOMAIN2PREDICATES = {
    'walker': [
        'taking a walk',
        'standing up vertically on both feet',
        'single-leg balancing',
        'standing upside down',
        'high kick',
        'walking',
        'stepping forward',
        'running fast',
        'standing on one bended knee',
        'lying down on the back with one raised leg',
        'sitting on the knees',
        'dog yoga pose',
        'lying down horizontally',
    ],
    'stickman': [
        'taking a walk',
        'standing up vertically',
        'one leg balancing',
        'high kick',
        'walking',
        'running fast',
        'praying',
        'lying down with one raised leg',
        'dog yoga pose',
        'lying down horizontally',
        'punching',
        'raised hands',
    ],
    'cheetah': [
        'jumping',
        'crawling',
        'running',
        'flipping',
        'standing up',
        'hopping',
        'lying down',
        'falling',
        'standing on the knees',
    ],
    'quadruped': [
        'jumping',
        'crawling',
        'walking',
        'standing up',
        'hopping',
        'lying down',
        'falling',
        'standing on the knees',
    ],
    'finger': [
        'spin',
        'touch',
        'rotate',
        'horizontal',
        'vertical',
        'not moving',
        'is not touching',
        'staying far away',
        'staying still',
    ],
    'pendulum': [
        'horizontal',
        'vertical',
        'left',
        'right',
        'swingup',
        'balance',
    ],
    # Same labels as 'quadruped', kept as a separate list object.
    'hopper': [
        'jumping',
        'crawling',
        'walking',
        'standing up',
        'hopping',
        'lying down',
        'falling',
        'standing on the knees',
    ],
    'reacher': [
        'horizontal',
        'vertical',
        'ball on the left',
        'ball on the right',
        'touch the ball with the elbow',
        'touch the ball with the tip',
        'arm reaches the sphere',
        'rotating',
        'bending',
        'keeping straight',
        'not moving',
        'is not touching',
    ],
    'jaco': [
        'horizontal',
        'vertical',
        'left',
        'right',
        'spin',
        'touch',
        'rotate',
        'bend',
        'straight',
        'is not touching',
    ],
    'kitchen': [
        # Generic manipulation verbs.
        'touch',
        'pick up',
        'lift',
        'grasp',
        'hold',
        'pull',
        'open',
        'close',
        'push',
        'sweep',
        'slide',
        # Task-specific descriptions.
        'switch light on',
        'open the microwave',
        'move the kettle',
        'turn on the burner',
    ],
}
# Default text prompt for each task, keyed by the full task name. It is used
# as the language target when the caller does not supply its own prompt.
TASK2PROMPT = {
    # --- quadruped ---
    'quadruped_run': 'spider running fast',
    'quadruped_walk': 'spider walking fast',
    'quadruped_stand': 'spider standing',
    'quadruped_jump': 'spider jumping',
    'quadruped_two_legs': 'on two legs',
    'quadruped_lie_down': 'lying down',
    # --- cheetah ---
    'cheetah_run': 'running like a quadruped',
    'cheetah_flipping': 'quadruped rotating flips',
    'cheetah_standing': 'standing like a human',
    'cheetah_lying_down': 'lying down',
    # --- stickman ---
    'stickman_walk': 'robot walk fast clean',
    'stickman_run': 'robot run fast clean',
    'stickman_stand': 'standing',
    'stickman_urlb_flip': 'doing flips',
    'stickman_flip': 'doing flips',
    'stickman_flipping': 'doing flips',
    'stickman_backflip': 'doing backflips',
    'stickman_one_foot': 'stand on one foot',
    'stickman_high_kick': 'stand up and kick',
    'stickman_lying_down': 'lying down horizontally',
    'stickman_legs_up': 'lying down with feet up',
    'stickman_sit_knees': 'praying',
    'stickman_lunge_pose': 'lunge_pose',
    'stickman_headstand': 'headstand',
    'stickman_boxing': 'punch',
    'stickman_hands_up': 'standing with the hands up',
    # --- walker ---
    'walker_walk': 'walk fast clean',
    'walker_run': 'run fast clean',
    'walker_stand': 'standing up straight',
    'walker_urlb_flip': 'doing backflips',
    'walker_flip': 'doing flips',
    'walker_flipping': 'doing backflips',
    'walker_backflip': 'doing backflips',
    'walker_one_foot': 'stand on one foot',
    'walker_high_kick': 'stand up and kick',
    'walker_lying_down': 'lying down horizontally',
    'walker_arabesque': 'arabesque position',
    'walker_legs_up': 'lying down with feet up',
    'walker_sit_knees': 'praying',
    'walker_lunge_pose': 'lunge_pose',
    'walker_headstand': 'headstand',
    # --- kitchen ---
    'kitchen_microwave': 'opening the microwave fully open',
    'kitchen_light': 'activate the light',
    'kitchen_burner': 'the burner becomes red',
    'kitchen_slide': 'slide cabinet above the knobs',
    'kitchen_kettle': 'pushing up the kettle',
    # --- jaco (all four reach variants share one prompt) ---
    'jaco_reach_top_left': 'robot grasp the red cube',
    'jaco_reach_bottom_left': 'robot grasp the red cube',
    'jaco_reach_top_right': 'robot grasp the red cube',
    'jaco_reach_bottom_right': 'robot grasp the red cube',
}
class ViCLIPGlobalInstance:
    """Lazy holder for the video-language model used by this module.

    Creating the object is cheap; the model itself is only loaded when
    ``instantiate`` is called. After a successful ``instantiate`` the
    following attributes are available:

    * ``viclip`` - the loaded model, moved to the requested device and put in
      eval mode. It is additionally tagged with ``device``, ``n_frames`` and
      ``preprocess_transf`` attributes.
    * ``viclip_tokenizer`` - the tokenizer returned alongside the model.
    * ``viclip_emb_dim`` - size of dimension 1 of the video features returned
      for a dummy clip.
    """

    def __init__(self, model='internvideo2'):
        """Remember which backbone to load; nothing is loaded yet.

        Args:
            model: backbone identifier. Only ``'internvideo2'`` is supported
                by ``instantiate``.
        """
        self._instantiated = False
        self._model = model

    def instantiate(self, device='cuda'):
        """Load the model, move it to ``device`` and probe its embedding size.

        ``_instantiated`` is flipped to True only once everything succeeded,
        so a failed load can be retried and callers checking the flag never
        see a half-initialised object.

        Args:
            device: device the model (and the dummy probe clip) is placed on.

        Raises:
            NotImplementedError: if the configured model is not supported.
        """
        # Reject unsupported backbones before importing or loading anything.
        if self._model != 'internvideo2':
            raise NotImplementedError(f"Model {self._model} not implemented")

        import sys
        from torchvision.transforms import transforms as vision_transf

        n_frames = 8
        multi_modality_dir = INTERNVIDEO_PATH / 'InternVideo2/multi_modality'
        demo_dir = multi_modality_dir / 'demo'

        # The InternVideo helpers are plain scripts, importable only while
        # their folders are on sys.path. Insert order matters: the
        # multi_modality folder ends up first, the demo folder second.
        added_paths = [str(demo_dir), str(multi_modality_dir)]
        for path in added_paths:
            sys.path.insert(0, path)
        try:
            from small_config import (Config, eval_dict_leaf)
            from small_utils import setup_internvideo2

            config = Config.from_file(demo_dir / 'internvideo2_stage2_config.py')
            config = eval_dict_leaf(config)
            config.model.vision_encoder.num_frames = n_frames
            config.num_frames = n_frames
            config.num_frames_test = n_frames
            # # >> can be configured in case the bert model doesn't load
            # config.model.text_encoder.pretrained = str(MODELS_ROOT_PATH / 'bert-large-uncased')
            config.model.text_encoder.config = str(multi_modality_dir) + "/" + config.model.text_encoder.config
            model_pth = str(MODELS_ROOT_PATH / 'InternVideo2-stage2_1b-224p-f4.pt')
            config.pretrained_path = model_pth
            config['model']['vision_encoder']['pretrained'] = model_pth
            intern_model, tokenizer = setup_internvideo2(config)
        finally:
            # Always undo the sys.path changes, even when loading fails, and
            # remove exactly the entries added above (not whatever happens to
            # sit at index 0 by now).
            for path in added_paths:
                if path in sys.path:
                    sys.path.remove(path)

        intern_model.device = device
        intern_model.to(device)
        intern_model.eval()
        intern_model.n_frames = n_frames
        intern_model.preprocess_transf = vision_transf.Compose([
            vision_transf.Resize(size=(224, 224), interpolation=vision_transf.InterpolationMode.BILINEAR),
            vision_transf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Probe the embedding width with an all-zero clip; no gradients needed.
        with torch.no_grad():
            vid_feat = intern_model.get_vid_features(
                torch.zeros(1, n_frames, 3, 224, 224, device=device)
            )

        self.viclip_tokenizer = tokenizer
        self.viclip = intern_model
        self.viclip_emb_dim = vid_feat.shape[1]
        self._instantiated = True
def report_text2video(agent, data,):
    """Decode one imagined video per text label of the agent's domain.

    The domain is the part of ``agent.cfg.task`` before the first underscore.
    Each of its labels in ``DOMAIN2PREDICATES`` is embedded with the text
    encoder, repeated ``connector.n_frames`` times along dimension 1, imagined
    by the connector and decoded into observations.

    Args:
        agent: object exposing ``cfg.task``, ``device`` and ``wm`` (the world
            model with ``heads['decoder']``, ``connector`` and
            ``decoder_input_fn``).
        data: unused.

    Returns:
        dict with the single key ``'text_to_video'`` holding the decoded
        observation mean shifted by +0.5.
    """
    domain = agent.cfg.task.split('_')[0]
    labels = DOMAIN2PREDICATES[domain]

    wm = agent.wm
    decoder = wm.heads['decoder']  # B, T, C, H, W
    n_frames = wm.connector.n_frames

    # Prefer a text/video model owned by the world model; otherwise fall back
    # to the lazily-loaded module-level instance.
    if hasattr(wm, 'viclip_model'):
        clip = wm.viclip_model
    else:
        shared = globals()['viclip_global_instance']
        if not shared._instantiated:
            shared.instantiate()
        clip = shared.viclip

    # One text embedding per label, stacked along a new leading dimension.
    with torch.no_grad():
        per_label_feat = [clip.get_txt_feat(label) for label in labels]
    text_feat = torch.stack(per_label_feat, dim=0)

    # The text embedding stands in for the video embedding of every frame.
    video_embed = text_feat.to(agent.device).repeat(1, n_frames, 1)

    prior = wm.connector.video_imagine(
        video_embed,
        dreamer_init=None,
        sample=False,
        reset_every_n_frames=False,
        denoise=True,
    )
    prior_recon = decoder(wm.decoder_input_fn(prior))['observation'].mean + 0.5
    return {'text_to_video': prior_recon}
def max_cosine_similarity(u, v, dim=-1, eps=1e-8):
    """Dot product of ``u`` and ``v`` after dividing both by the larger norm.

    Unlike plain cosine similarity, both vectors are scaled by the same
    factor, ``max(||u||, ||v||)``. The score reaches 1 only when they agree in
    direction and in magnitude.

    Args:
        u: first tensor.
        v: second tensor, broadcastable against ``u``.
        dim: dimension holding the vector components. The norms keep this
            dimension so the division broadcasts correctly for any ``dim``,
            not only the last one.
        eps: lower bound on the shared norm. It prevents a 0/0 NaN when both
            vectors are (near) zero; the score is then 0.

    Returns:
        Tensor with ``dim`` reduced away.
    """
    u_norm = torch.norm(u, dim=dim, keepdim=True)
    v_norm = torch.norm(v, dim=dim, keepdim=True)
    max_norm = torch.max(u_norm, v_norm).clamp_min(eps)
    return torch.sum((u / max_norm) * (v / max_norm), dim=dim)
def neg_mse_fn(a, b, dim=-1, scale=True):
    """Negative Euclidean distance between ``a`` and ``b`` along ``dim``.

    Despite the name, this is the negated L2 norm of ``a - b``, not a mean of
    squares. With ``scale=True`` the result is divided by the square root of
    the number of elements reduced, which gives the negative root-mean-square
    error.

    Args:
        a: first tensor.
        b: second tensor, broadcastable against ``a``.
        dim: dimension to reduce.
        scale: whether to divide by ``sqrt(size of dim)``.

    Returns:
        Tensor of non-positive distances with ``dim`` reduced away.
    """
    diff = a - b
    dist = -torch.norm(diff, dim=dim)
    if scale:
        # Use the size of the dimension actually reduced (after broadcasting),
        # not unconditionally the last one.
        dist = dist / float(np.sqrt(diff.shape[dim]))
    return dist
def compute_reward(agent, agent_seq, target_seq, score_fn='cosine',):
    """Score how well the agent's latent sequence matches a target sequence.

    Args:
        agent: object exposing the world model as ``agent.wm`` (``rssm``,
            ``connector`` and ``heads['decoder']`` are used).
        agent_seq: dict of latent statistics for the agent.
        target_seq: dict of latent statistics for the target.
        score_fn: one of

            * ``'cosine'``, ``'max_cosine'``, ``'neg_mse'``, ``'exp_neg_mse'``:
              both stochastic states go through the first layer of the
              decoder's ``_conv_in`` and are compared along the last
              dimension (``'exp_neg_mse'`` exponentiates the result);
            * ``'neg_kl'``: negative KL(agent || target), divided by
              ``log(n_classes) * n_groups`` when ``target_seq`` has a
              ``'logit'`` entry, else by the last dimension of ``'mean'``;
            * ``'max_like'``: log-probability of ``target_seq['stoch']`` under
              the agent's distribution;
            * ``'combo'``: sum of the ``'cosine'`` and ``'neg_kl'`` scores.

    Returns:
        Reward tensor.

    Raises:
        NotImplementedError: for any other ``score_fn``.
    """
    if score_fn == 'combo':
        cosine_part = compute_reward(agent, agent_seq, target_seq, 'cosine')
        kl_part = compute_reward(agent, agent_seq, target_seq, 'neg_kl')
        return cosine_part + kl_part

    if score_fn == 'max_like':
        agent_dist = agent.wm.rssm.get_dist(agent_seq)
        return agent_dist.log_prob(target_seq['stoch'])

    if score_fn == 'neg_kl':
        agent_dist = agent.wm.rssm.get_dist(agent_seq)
        target_dist = agent.wm.connector.get_dist(target_seq)
        neg_kl = -torch.distributions.kl_divergence(agent_dist, target_dist)
        # scaling factor ( x log x w.r.t. to classes, or just x)
        if 'logit' in target_seq:
            logit_shape = target_seq['logit'].shape
            return neg_kl / (np.log(logit_shape[-1]) * logit_shape[-2])
        return neg_kl / target_seq['mean'].shape[-1]

    if score_fn not in ('cosine', 'max_cosine', 'neg_mse', 'exp_neg_mse'):
        raise NotImplementedError(f"{score_fn} reward not implemented")

    # Remaining options all compare the latents in the decoder's input space.
    if score_fn == 'cosine':
        similarity = F.cosine_similarity
    elif score_fn == 'max_cosine':
        similarity = max_cosine_similarity
    else:  # 'neg_mse' and 'exp_neg_mse' share the same base distance
        similarity = neg_mse_fn

    target_stoch = agent.wm.connector.get_stoch(target_seq)
    agent_stoch = agent.wm.rssm.get_stoch(agent_seq)
    project = agent.wm.heads['decoder']._conv_in[0]
    conv_target = project(target_stoch)
    conv_agent = project(agent_stoch)

    score = similarity(conv_target, conv_agent, dim=-1)
    if score_fn == 'exp_neg_mse':
        score = torch.exp(score)
    return score
def _discount_alignment_scores(score):
    """Multiply ``score`` by a 0.99-based cumulative product along dimension 1."""
    # NOTE(review): the scores are later arg-maxed along dimension 0, while the
    # decay here runs along dimension 1. With a positive weight that is
    # constant along dimension 0 the arg-max cannot change, so this weighting
    # looks like a no-op. Confirm whether dim=0 was intended. Behaviour is
    # kept as-is.
    weights = 0.99 * torch.ones_like(score, device=score.device)
    weights = torch.cumprod(weights, dim=1)
    return weights * score


def _shift_target_to_best_match(target_seq, score):
    """Re-index the target along dimension 0 so it starts at the best match.

    The arg-max of ``score`` along dimension 0 gives a start index per column.
    Positions up to and including that index map to target step 0, and every
    later position advances by one step (0, ..., 0, 1, 2, ...).
    """
    n_target_steps = target_seq['stoch'].shape[0]
    best_one_hot = F.one_hot(torch.argmax(score, dim=0), num_classes=n_target_steps)
    # Double cumsum of a one-hot row yields 0..0,1,2,3,..; "-1" with clipping
    # turns it into 0..0,0,1,2,.. which is the per-position target index.
    ts_idx = torch.clip(torch.cumsum(torch.cumsum(best_one_hot, dim=1), dim=1) - 1, min=0).T

    shifted = {}
    for key, value in target_seq.items():
        if len(value.shape) == 4:
            index = ts_idx.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, value.shape[-2], value.shape[-1])
        else:
            index = ts_idx.unsqueeze(-1).repeat(1, 1, value.shape[-1])
        # out[i][j][k] = input[index[i][j][k]][j][k]
        shifted[key] = torch.gather(value, 0, index)
    return shifted


def video_text_reward(agent, seq, score_fn='cosine',
                      sample_for_target=False, weighted_align=False, align_initial=False, align_sequence=False,
                      task_prompt='', skip_first_target=False, **kwargs):
    """Reward an imagined agent sequence for matching a text-conditioned target.

    On the first call the text prompt is embedded, repeated for every step and
    batch element, and turned into a target latent sequence by the connector.
    That target is cached on ``agent.unconditional_target`` and reused by every
    later call, regardless of the arguments passed then.

    Args:
        agent: agent exposing ``wm``, ``device`` and ``cfg.task``.
        seq: dict of the agent's latent statistics; the first two dimensions
            of ``seq['deter']`` are read as (T, B).
        score_fn: scoring mode forwarded to ``compute_reward``.
        sample_for_target: forwarded as ``sample`` to ``video_imagine``.
        weighted_align: apply the decay weighting to the alignment scores.
        align_initial: align using the score against the first target step.
        align_sequence: align using the mean score of sliding windows of
            ``connector.n_frames`` steps against the first target window.
        task_prompt: prompt to use; when empty, ``TASK2PROMPT[agent.cfg.task]``.
        skip_first_target: imagine one extra step and drop the first one.
        **kwargs: accepted and ignored.

    Returns:
        Reward tensor from ``compute_reward`` with a trailing singleton
        dimension added.

    Raises:
        AssertionError: if both ``align_initial`` and ``align_sequence`` are set.
    """
    wm = agent.wm
    n_frames = wm.connector.n_frames
    T, B = seq['deter'].shape[:2]

    if not hasattr(agent, 'unconditional_target'):
        # Prefer a text/video model owned by the world model; otherwise fall
        # back to the lazily-loaded module-level instance.
        if hasattr(wm, 'viclip_model'):
            clip = wm.viclip_model
        else:
            shared = viclip_global_instance
            if not shared._instantiated:
                shared.instantiate()
            clip = shared.viclip

        prompt = task_prompt if task_prompt != '' else TASK2PROMPT[agent.cfg.task]
        with torch.no_grad():
            text_feat = clip.get_txt_feat(prompt)
        video_embed = text_feat.to(agent.device)

        # The same text embedding conditions every step of every batch element.
        n_steps = T + 1 if skip_first_target else T
        video_embed = video_embed.reshape(1, 1, -1).repeat(B, n_steps, 1)
        imagined = wm.connector.video_imagine(
            video_embed,
            dreamer_init=None,
            sample=sample_for_target,
            reset_every_n_frames=False,
            denoise=True,
        )
        if skip_first_target:
            imagined = {k: v[:, 1:] for k, v in imagined.items()}
        # Swap the first two dimensions so the target is laid out like ``seq``.
        target_seq = {
            k: v.permute([1, 0] + list(range(2, len(v.shape))))
            for k, v in imagined.items()
        }
        agent.unconditional_target = target_seq
    else:
        target_seq = agent.unconditional_target

    if align_initial:
        assert not align_sequence, 'Cannot align initial and sequence at the same time'
        first_target_step = {k: v[0] for k, v in target_seq.items()}
        align_score = compute_reward(agent, seq, first_target_step, score_fn=score_fn)
    elif align_sequence:
        target_window = {k: v[0:n_frames] for k, v in target_seq.items()}
        window_scores = []
        for t in range(T - n_frames):
            agent_window = {k: v[t:t + n_frames] for k, v in seq.items()}
            # Average over dimension 0, the window's step dimension.
            window_scores.append(
                compute_reward(agent, agent_window, target_window, score_fn=score_fn).mean(dim=0)
            )
        align_score = torch.stack(window_scores, dim=0)
    else:
        # No alignment: compare step by step against the cached target.
        return compute_reward(agent, seq, target_seq, score_fn=score_fn).unsqueeze(-1)

    if weighted_align:
        align_score = _discount_alignment_scores(align_score)
    aligned_target = _shift_target_to_best_match(target_seq, align_score)
    return compute_reward(agent, seq, aligned_target, score_fn=score_fn).unsqueeze(-1)
# Module-wide lazy holder for the text/video model. Creating it loads nothing;
# the functions above call ``instantiate()`` on first use.
viclip_global_instance = ViCLIPGlobalInstance()