from collections import defaultdict

import torch
import torch.nn.functional as F


@torch.no_grad()
def log_sample_res(
    text_encoder, vision_encoder, rdt, args,
    accelerator, weight_dtype, dataset_id2name, dataloader, logger
):
    """Sample actions on held-out batches and log per-dataset errors.

    Runs `rdt.predict_action` on `args.num_sample_batches` batches from
    `dataloader` and returns a dict mapping metric names (per-dataset
    `*_sample_mse` / `*_sample_l2err` plus two overall averages) to floats.
    """
    logger.info(
        f"Running sampling for {args.num_sample_batches} batches..."
    )

    # Switch to eval mode for sampling; restored to train mode at the end.
    rdt.eval()
    
    # Accumulated metric sums and per-metric sample counts, for averaging at the end.
    loss_for_log = defaultdict(float)
    loss_counter = defaultdict(int)
    for step, batch in enumerate(dataloader):
        if step >= args.num_sample_batches:
            break
        
        data_indices = batch["data_indices"]
        ctrl_freqs = batch["ctrl_freqs"]
        state_norm = batch["state_norm"].to(dtype=weight_dtype)
        images = batch["images"].to(dtype=weight_dtype)
        states = batch["states"].to(dtype=weight_dtype)
        # We only use the last state as input
        states = states[:, -1:, :]
        actions = batch["actions"].to(dtype=weight_dtype)
        state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)
            
        # Fold all camera views into the batch dim for a single encoder pass,
        # then restore the batch dimension on the resulting image tokens.
        batch_size, _, C, H, W = images.shape
        image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
        image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))
        
        lang_attn_mask = batch["lang_attn_mask"]
        if args.precomp_lang_embed:
            # Language embeddings were precomputed offline.
            text_embeds = batch["lang_embeds"].to(dtype=weight_dtype)
        else:
            text_embeds = text_encoder(
                input_ids=batch["input_ids"],
                attention_mask=lang_attn_mask
            )["last_hidden_state"].detach()
            
        # Sample an action trajectory under bf16 autocast.
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            pred_actions = rdt.predict_action(
                lang_tokens=text_embeds,
                lang_attn_mask=lang_attn_mask,
                img_tokens=image_embeds,
                state_tokens=states,
                action_mask=state_elem_mask.unsqueeze(1),
                ctrl_freqs=ctrl_freqs
            )
        
        # Broadcast the per-element state mask and normalization stats over the
        # action horizon so they align with the predicted trajectory.
        num_steps = pred_actions.shape[1]
        expanded_state_elem_mask = state_elem_mask.unsqueeze(1).tile((1, num_steps, 1)).float()
        expanded_state_norm = state_norm.unsqueeze(1).tile((1, num_steps, 1)).float()

        loss = F.mse_loss(pred_actions, actions, reduction='none').float()

        # Per-sample masked mean over the valid state elements.
        mse_loss_per_entry = ((loss * expanded_state_elem_mask).reshape((batch_size, -1)).sum(1)
                              / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1))
        # L2 error normalized by the dataset's state scale (eps avoids division by zero).
        l2_loss_per_entry = loss.sqrt() / (expanded_state_norm + 1e-3)
        l2_loss_per_entry = ((l2_loss_per_entry * expanded_state_elem_mask).reshape((batch_size, -1)).sum(1)
                             / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1))

        # Gather per-sample metrics from all processes; only the main process
        # accumulates them into the per-dataset log.
        dataset_indices, mse_losses, l2_losses = accelerator.gather_for_metrics(
            (torch.tensor(data_indices, dtype=torch.long, device=pred_actions.device),
             mse_loss_per_entry, l2_loss_per_entry),
        )
        dataset_indices = dataset_indices.tolist()
        if accelerator.is_main_process:
            for loss_suffix, losses in zip(["_sample_mse", "_sample_l2err"], [mse_losses, l2_losses]):
                for dataset_idx, loss_tensor in zip(dataset_indices, losses):
                    loss_name = dataset_id2name[dataset_idx] + loss_suffix
                    loss_for_log[loss_name] += loss_tensor.item()
                    loss_counter[loss_name] += 1
        
        # Overall masked-mean MSE across the batch, averaged over processes.
        mse_loss = (loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
        mse_loss_scalar = accelerator.gather(mse_loss).mean().item()
        loss_for_log["overall_avg_sample_mse"] += mse_loss_scalar

        # Overall normalized L2 error, with the same masked reduction as above.
        l2_loss = loss.sqrt() / (expanded_state_norm + 1e-3)
        l2_loss = (l2_loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
        l2_loss_scalar = accelerator.gather(l2_loss).mean().item()
        loss_for_log["overall_avg_sample_l2err"] += l2_loss_scalar

    # Average the accumulated sums: overall metrics over the number of sampled
    # batches (assumes the dataloader yields at least that many), per-dataset
    # metrics over the number of gathered samples for that dataset.
    for name in loss_for_log:
        if name in ["overall_avg_sample_mse", "overall_avg_sample_l2err"]:
            loss_for_log[name] = round(loss_for_log[name] / args.num_sample_batches, 4)
        else:
            loss_for_log[name] = round(loss_for_log[name] / loss_counter[name], 4)
    
    # Restore training mode and release cached GPU memory before resuming.
    rdt.train()
    torch.cuda.empty_cache()

    return dict(loss_for_log)
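

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): a minimal
# example of calling this helper from a training loop. The names `model`,
# `vision_tower`, `sample_dataloader`, and `global_step` are hypothetical
# stand-ins for the caller's objects; `args` is assumed to carry at least
# `num_sample_batches` and `precomp_lang_embed`.
# ---------------------------------------------------------------------------
# sample_metrics = log_sample_res(
#     text_encoder=text_encoder,
#     vision_encoder=vision_tower,
#     rdt=model,
#     args=args,
#     accelerator=accelerator,
#     weight_dtype=torch.bfloat16,
#     dataset_id2name=dataset_id2name,
#     dataloader=sample_dataloader,
#     logger=logger,
# )
# accelerator.log(sample_metrics, step=global_step)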