import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
MODELS_ROOT_PATH = Path(__file__).parent.parent / 'models'
INTERNVIDEO_PATH = Path(__file__).parent.parent / 'third_party' / 'InternVideo'
DOMAIN2PREDICATES = {
'walker' : ['taking a walk', 'standing up vertically on both feet', 'single-leg balancing', "standing upside down", 'high kick', 'walking', 'stepping forward', 'running fast',
'standing on one bended knee', 'lying down on the back with one raised leg', 'sitting on the knees', 'dog yoga pose', 'lying down horizontally', ],
'stickman' : ['taking a walk', 'standing up vertically', 'one leg balancing', 'high kick', 'walking', 'running fast',
'praying', 'lying down with one raised leg', 'dog yoga pose', 'lying down horizontally', 'punching', 'raised hands' ],
'cheetah' : ['jumping', 'crawling', 'running', 'flipping', 'standing up', 'hopping', 'lying down', 'falling',
'standing on the knees'],
'quadruped' : ['jumping', 'crawling', 'walking', 'standing up',
'hopping', 'lying down', 'falling', 'standing on the knees'],
'finger' : ['spin', 'touch', 'rotate', 'horizontal', 'vertical', "not moving", "is not touching", "staying far away", "staying still"],
'pendulum' : ['horizontal', 'vertical', 'left', 'right',
'swingup', 'balance'],
'hopper' : ['jumping', 'crawling', 'walking', 'standing up',
'hopping', 'lying down', 'falling', 'standing on the knees'],
'reacher' : ['horizontal', 'vertical', 'ball on the left', 'ball on the right', 'touch the ball with the elbow', 'touch the ball with the tip', 'arm reaches the sphere', 'rotating', 'bending', 'keeping straight', "not moving", "is not touching"],
'jaco' : ['horizontal', 'vertical', 'left', 'right', 'spin', 'touch', 'rotate', 'bend', 'straight', "is not touching"],
'kitchen' : [ "touch", "pick up", "lift", "grasp", "hold", "pull", "open", "close",
"push", "sweep", "slide"] + ['switch light on', 'open the microwave', 'move the kettle', 'turn on the burner'],
}
TASK2PROMPT = {
"quadruped_run" : 'spider running fast',
"quadruped_walk" : 'spider walking fast',
"quadruped_stand" : 'spider standing',
"quadruped_jump" : 'spider jumping',
"quadruped_two_legs" : 'on two legs',
"quadruped_lie_down" : 'lying down',
"cheetah_run" : 'running like a quadruped',
"cheetah_flipping" : 'quadruped rotating flips',
"cheetah_standing" : 'standing like a human',
"cheetah_lying_down" : 'lying down',
'stickman_walk' : 'robot walk fast clean',
'stickman_run' : 'robot run fast clean',
'stickman_stand' : 'standing',
'stickman_urlb_flip' : 'doing flips',
'stickman_flip' : 'doing flips',
'stickman_flipping' : 'doing flips',
'stickman_backflip' : 'doing backflips',
'stickman_one_foot' : 'stand on one foot',
'stickman_high_kick' : 'stand up and kick',
'stickman_lying_down' : 'lying down horizontally',
'stickman_legs_up' : 'lying down with feet up',
'stickman_sit_knees' : 'praying',
'stickman_lunge_pose' : 'lunge_pose',
'stickman_headstand' : 'headstand',
'stickman_boxing' : 'punch',
'stickman_hands_up' : 'standing with the hands up',
'walker_walk' : 'walk fast clean',
'walker_run' : 'run fast clean',
'walker_stand' : 'standing up straight',
'walker_urlb_flip' : 'doing backflips',
'walker_flip' : 'doing flips',
'walker_flipping' : 'doing backflips',
'walker_backflip' : 'doing backflips',
'walker_one_foot' : 'stand on one foot',
'walker_high_kick' : 'stand up and kick',
'walker_lying_down' : 'lying down horizontally',
'walker_arabesque' : 'arabesque position',
'walker_legs_up' : 'lying down with feet up',
'walker_sit_knees' : 'praying',
'walker_lunge_pose' : 'lunge_pose',
'walker_headstand' : 'headstand',
'kitchen_microwave' : 'opening the microwave fully open',
'kitchen_light' : 'activate the light',
'kitchen_burner' : 'the burner becomes red',
'kitchen_slide' : 'slide cabinet above the knobs',
'kitchen_kettle' : 'pushing up the kettle',
'jaco_reach_top_left' : 'robot grasp the red cube',
'jaco_reach_bottom_left' : 'robot grasp the red cube',
'jaco_reach_top_right' : 'robot grasp the red cube',
'jaco_reach_bottom_right' : 'robot grasp the red cube',
}
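# Illustrative sketch (the helper name `_example_task_lookup` is hypothetical, not part of
# the original module): shows how the two tables above are combined, mirroring the lookups
# performed in `report_text2video` and `video_text_reward` below.
def _example_task_lookup(task: str = 'walker_run', task_prompt: str = ''):
    domain = task.split('_')[0]                 # e.g. 'walker'
    predicates = DOMAIN2PREDICATES[domain]      # candidate behaviours decoded for reports
    # An explicit task_prompt overrides the per-task default prompt
    prompt = task_prompt if task_prompt != '' else TASK2PROMPT[task]
    return predicates, prompt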
class ViCLIPGlobalInstance:
    """Lazy wrapper around the InternVideo2 (ViCLIP-style) video-text model.

    The heavy model is only loaded when `instantiate()` is called, so importing
    this module stays cheap.
    """
def __init__(self, model='internvideo2'):
self._instantiated = False
self._model = model
def instantiate(self, device='cuda'):
from torchvision.transforms import transforms as vision_transf
import sys
self._instantiated = True
        if self._model == 'internvideo2':
            sys.path.insert(0, str(INTERNVIDEO_PATH / 'InternVideo2/multi_modality/demo/'))
            sys.path.insert(0, str(INTERNVIDEO_PATH / 'InternVideo2/multi_modality'))
from small_config import (Config, eval_dict_leaf)
from small_utils import setup_internvideo2
config = Config.from_file(INTERNVIDEO_PATH / 'InternVideo2/multi_modality/demo/internvideo2_stage2_config.py')
config = eval_dict_leaf(config)
config.model.vision_encoder.num_frames = 8
config.num_frames = 8
config.num_frames_test = 8
            # >> Optionally point the text encoder at a local BERT copy if the
            # pretrained weights fail to download:
            # config.model.text_encoder.pretrained = str(MODELS_ROOT_PATH / 'bert-large-uncased')
config.model.text_encoder.config = str(INTERNVIDEO_PATH / 'InternVideo2/multi_modality') + "/" + config.model.text_encoder.config
model_pth = str(MODELS_ROOT_PATH / 'InternVideo2-stage2_1b-224p-f4.pt')
config.pretrained_path = model_pth
config['model']['vision_encoder']['pretrained'] = model_pth
intern_model, tokenizer = setup_internvideo2(config)
self.viclip_tokenizer = tokenizer
self.viclip = intern_model
self.viclip.device = device
self.viclip.to(self.viclip.device)
self.viclip.eval()
self.viclip.n_frames = 8
self.viclip.preprocess_transf = vision_transf.Compose([
vision_transf.Resize(size=(224, 224), interpolation=vision_transf.InterpolationMode.BILINEAR),
vision_transf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
sys.path.pop(0)
sys.path.pop(0)
else:
raise NotImplementedError(f"Model {self._model} not implemented")
        # Probe the video encoder once to record the embedding dimensionality.
        vid_feat = self.viclip.get_vid_features(torch.zeros(1, self.viclip.n_frames, 3, 224, 224, device=self.viclip.device))
        self.viclip_emb_dim = vid_feat.shape[1]
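# Minimal usage sketch (the helper `_example_encode_prompt` is hypothetical, not part of
# the original module): the wrapper above is instantiated lazily and then queried for text
# features, as the reward and report functions below do.
def _example_encode_prompt(prompt: str, device: str = 'cuda'):
    instance = viclip_global_instance  # module-level singleton defined at the bottom of this file
    if not instance._instantiated:
        instance.instantiate(device=device)  # loads the InternVideo2 weights on first use
    with torch.no_grad():
        return instance.viclip.get_txt_feat(prompt)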
def report_text2video(agent, data,):
report = {}
domain = agent.cfg.task.split('_')[0]
labels_list = DOMAIN2PREDICATES[domain]
wm = world_model = agent.wm
decoder = world_model.heads['decoder'] # B, T, C, H, W
connector = agent.wm.connector
n_frames = connector.n_frames
if hasattr(world_model, 'viclip_model'):
clip = world_model.viclip_model
else:
# Get ViCLIP
viclip_global_instance = globals()['viclip_global_instance']
if not viclip_global_instance._instantiated:
viclip_global_instance.instantiate()
clip = viclip_global_instance.viclip
    # Encode each predicate as a text feature (used in place of a video embedding)
text_feat = []
for text in labels_list:
with torch.no_grad():
text_feat.append(clip.get_txt_feat(text,))
text_feat = torch.stack(text_feat, dim=0)
    # Move the text features to the agent's device
    video_embed = text_feat.to(agent.device)
    B = video_embed.shape[0]
    # Tile each text embedding across the connector's frame window
    video_embed = video_embed.repeat(1, n_frames, 1)
# Imagine
prior = wm.connector.video_imagine(video_embed, dreamer_init=None, sample=False, reset_every_n_frames=False, denoise=True)
prior_recon = decoder(wm.decoder_input_fn(prior))['observation'].mean + 0.5
    report['text_to_video'] = prior_recon
return report
def max_cosine_similarity(u, v, dim=-1):
    # Cosine-like similarity where both vectors are normalized by the larger of the two
    # norms, so a magnitude mismatch lowers the score.
    max_norm = torch.max(torch.norm(u, dim=dim), torch.norm(v, dim=dim)).unsqueeze(-1)
    return torch.sum((u / max_norm) * (v / max_norm), dim=dim)
def neg_mse_fn(a, b, dim=-1, scale=True):
    # Negative Euclidean distance, optionally scaled by sqrt(feature_dim).
    dist = -torch.norm(a - b, dim=dim)
    if scale:
        dist = dist / np.sqrt(a.shape[-1]).item()
    return dist
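# Illustrative comparison of the similarity functions above (hypothetical helper, values are
# only indicative): `max_cosine_similarity` normalizes both vectors by the larger of the two
# norms, so unlike plain cosine it penalizes a magnitude mismatch.
def _example_metric_comparison():
    a = torch.randn(4, 128)
    b = 2.0 * a                                    # same direction, twice the magnitude
    cos = F.cosine_similarity(a, b, dim=-1)        # ~1.0: invariant to magnitude
    max_cos = max_cosine_similarity(a, b, dim=-1)  # ~0.5: the norm gap lowers the score
    mse = neg_mse_fn(a, b, dim=-1)                 # <= 0: negative distance scaled by sqrt(128)
    return cos, max_cos, mse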
def compute_reward(agent, agent_seq, target_seq, score_fn='cosine',):
if score_fn in ['cosine', 'max_cosine', 'neg_mse', 'exp_neg_mse']:
distance_fn = dict(cosine=F.cosine_similarity, max_cosine=max_cosine_similarity, neg_mse=neg_mse_fn, exp_neg_mse=neg_mse_fn)[score_fn]
        target_stoch = agent.wm.connector.get_stoch(target_seq)
        agent_stoch = agent.wm.rssm.get_stoch(agent_seq)
        # Project both latents through the decoder's first input layer and compare them there
        conv_target = agent.wm.heads['decoder']._conv_in[0](target_stoch)
        conv_agent = agent.wm.heads['decoder']._conv_in[0](agent_stoch)
reward = distance_fn(conv_target, conv_agent, dim=-1)
if score_fn == 'exp_neg_mse':
reward = torch.exp(reward)
elif score_fn == 'neg_kl':
agent_dist = agent.wm.rssm.get_dist( agent_seq )
target_dist = agent.wm.connector.get_dist( target_seq )
reward = -torch.distributions.kl_divergence(agent_dist, target_dist,)
        # Normalize: by groups * log(classes) for discrete latents, by the feature size for Gaussian latents
        if 'logit' in target_seq:
            reward = reward / (np.log(target_seq['logit'].shape[-1]) * target_seq['logit'].shape[-2])
        else:
            reward = reward / target_seq['mean'].shape[-1]
elif score_fn == 'max_like':
agent_dist = agent.wm.rssm.get_dist( agent_seq )
target_sample = target_seq['stoch']
reward = agent_dist.log_prob(target_sample)
elif score_fn == 'combo':
return compute_reward(agent, agent_seq, target_seq, 'cosine') + compute_reward(agent, agent_seq, target_seq, 'neg_kl')
else:
raise NotImplementedError(f"{score_fn} reward not implemented")
return reward
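# Worked example of the 'neg_kl' normalization above (assuming a discrete latent whose
# `logit` tensor has shape [..., groups, classes]): the KL divergence is divided by
# groups * log(classes) so that the reward scale is comparable across latent sizes.
def _example_neg_kl_scale(groups: int = 32, classes: int = 32) -> float:
    return float(np.log(classes) * groups)  # ~110.9 for a 32x32 discrete latent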
def video_text_reward(agent, seq, score_fn='cosine',
sample_for_target=False, weighted_align=False, align_initial=False, align_sequence=False,
task_prompt='', skip_first_target=False, **kwargs):
wm = world_model = agent.wm
connector = agent.wm.connector
n_frames = connector.n_frames
T, B = seq['deter'].shape[:2]
imagined_steps = T
if not hasattr(agent, 'unconditional_target'):
if hasattr(world_model, 'viclip_model'):
clip = world_model.viclip_model
else:
# Get ViCLIP
viclip_global_instance = globals()['viclip_global_instance']
if not viclip_global_instance._instantiated:
viclip_global_instance.instantiate()
clip = viclip_global_instance.viclip
if task_prompt != '':
task = [task_prompt]
else:
task = [ TASK2PROMPT[agent.cfg.task] ]
        # Encode the task prompt as a text feature (used in place of a video embedding)
        with torch.no_grad():
            text_feat = clip.get_txt_feat(task[0])
        # Move the text feature to the agent's device
        video_embed = text_feat.to(agent.device)
        # Imagine the target trajectory from the prompt once and cache it on the agent
if skip_first_target:
video_embed = video_embed.reshape(1, 1, -1).repeat(B, imagined_steps + 1, 1)
unconditional_stats = wm.connector.video_imagine(video_embed, dreamer_init=None, sample=sample_for_target, reset_every_n_frames=False, denoise=True)
unconditional_stats = { k: v[:,1:].permute([1,0] + list(range(2, len(v.shape)))) for k,v in unconditional_stats.items() }
else:
video_embed = video_embed.reshape(1, 1, -1).repeat(B, imagined_steps, 1)
unconditional_stats = wm.connector.video_imagine(video_embed, dreamer_init=None, sample=sample_for_target, reset_every_n_frames=False, denoise=True)
unconditional_stats = { k: v.permute([1,0] + list(range(2, len(v.shape)))) for k,v in unconditional_stats.items() }
agent.unconditional_target = unconditional_stats
else:
unconditional_stats = agent.unconditional_target
agent_seq = seq
target_seq = unconditional_stats
if align_initial:
assert not align_sequence, 'Cannot align initial and sequence at the same time'
init_seq = { k: v[0] for k,v in target_seq.items() }
init_score = compute_reward(agent, agent_seq, init_seq, score_fn=score_fn,)
if weighted_align:
w = 0.99 * torch.ones_like(init_score, device=init_score.device)
w = torch.cumprod(w, dim=1)
init_score = w * init_score
        # Align the target so it starts at the agent timestep that best matches the target's initial state
best_indexes_one_hot = F.one_hot(torch.argmax(init_score, dim=0), num_classes=target_seq['stoch'].shape[0])
ts_idx = torch.clip(torch.cumsum(torch.cumsum(best_indexes_one_hot, dim=1), dim=1) - 1, min=0).T
new_target_seq = {}
for k,v in target_seq.items():
if len(v.shape) == 4:
new_ts = ts_idx.unsqueeze(-1).unsqueeze(-1).repeat(1,1, v.shape[-2], v.shape[-1])
else:
new_ts = ts_idx.unsqueeze(-1).repeat(1,1, v.shape[-1])
new_target_seq[k] = torch.gather(v, 0, new_ts) # out[i][j][k] = input[index[i][j][k]][j][k]
return compute_reward(agent, agent_seq, new_target_seq, score_fn=score_fn,).unsqueeze(-1)
elif align_sequence:
align_score = []
get_prev_a_b = lambda d, a, b : { k : v[a:b] for k,v in d.items() }
shorter_target_seq = get_prev_a_b(unconditional_stats, 0, n_frames)
for t in range(T-n_frames):
cur_agent_seq = get_prev_a_b(seq, t, t+n_frames)
score = compute_reward(agent, cur_agent_seq, shorter_target_seq, score_fn=score_fn,).mean(dim=0) # 0 is time dimension
align_score.append(score)
align_score = torch.stack(align_score, dim=0)
if weighted_align:
w = 0.99 * torch.ones_like(align_score, device=align_score.device)
w = torch.cumprod(w, dim=1)
align_score = w * align_score
best_indexes_one_hot = F.one_hot(torch.argmax(align_score, dim=0), num_classes=target_seq['stoch'].shape[0])
ts_idx = torch.clip(torch.cumsum(torch.cumsum(best_indexes_one_hot, dim=1), dim=1) - 1, min=0).T
new_target_seq = {}
for k,v in target_seq.items():
if len(v.shape) == 4:
new_ts = ts_idx.unsqueeze(-1).unsqueeze(-1).repeat(1,1, v.shape[-2], v.shape[-1])
else:
new_ts = ts_idx.unsqueeze(-1).repeat(1,1, v.shape[-1])
new_target_seq[k] = torch.gather(v, 0, new_ts) # out[i][j][k] = input[index[i][j][k]][j][k]
return compute_reward(agent, agent_seq, new_target_seq, score_fn=score_fn,).unsqueeze(-1)
else:
neg_kl = compute_reward(agent, agent_seq, target_seq, score_fn=score_fn,)
return neg_kl.unsqueeze(-1)
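# Standalone sketch of the index trick used in the alignment branches above (illustrative
# helper, not part of the original module): a one-hot of the best-matching timestep is
# double-cumsummed so the target index stays at 0 until the match, then advances by one per step.
def _example_alignment_indices(best_t: int = 2, T: int = 6):
    one_hot = F.one_hot(torch.tensor([best_t]), num_classes=T)  # shape (1, T)
    ts_idx = torch.clip(torch.cumsum(torch.cumsum(one_hot, dim=1), dim=1) - 1, min=0)
    return ts_idx  # best_t=2, T=6 -> [[0, 0, 0, 1, 2, 3]]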
# Module-level singleton; the underlying model is loaded lazily on first use.
viclip_global_instance = ViCLIPGlobalInstance()