Upload 35 files
delta-iris/src/models/__init__.py
CHANGED
@@ -0,0 +1 @@
+from .tokenizer import Tokenizer
delta-iris/src/models/convnet.py
CHANGED
@@ -7,25 +7,16 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-@dataclass
-class FrameCnnConfig:
-    image_channels: int
-    latent_dim: int
-    num_channels: int
-    mult: List[int]
-    down: List[int]
-
-
 class FrameEncoder(nn.Module):
-    def __init__(self, config: FrameCnnConfig) -> None:
+    def __init__(self, config: dict) -> None:
         super().__init__()
 
-        assert len(config.mult) == len(config.down)
-        encoder_layers = [nn.Conv2d(config.image_channels, config.num_channels, kernel_size=3, stride=1, padding=1)]
-        input_channels = config.num_channels
+        assert len(config["mult"]) == len(config["down"])
+        encoder_layers = [nn.Conv2d(config["image_channels"], config["num_channels"], kernel_size=3, stride=1, padding=1)]
+        input_channels = config["num_channels"]
 
-        for m, d in zip(config.mult, config.down):
-            output_channels = m * config.num_channels
+        for m, d in zip(config["mult"], config["down"]):
+            output_channels = m * config["num_channels"]
             encoder_layers.append(ResidualBlock(input_channels, output_channels))
             input_channels = output_channels
             if d:
@@ -33,7 +24,7 @@ class FrameEncoder(nn.Module):
         encoder_layers.extend([
             nn.GroupNorm(num_groups=32, num_channels=input_channels),
             nn.SiLU(inplace=True),
-            nn.Conv2d(input_channels, config.latent_dim, kernel_size=3, stride=1, padding=1)
+            nn.Conv2d(input_channels, config["latent_dim"], kernel_size=3, stride=1, padding=1)
         ])
         self.encoder = nn.Sequential(*encoder_layers)
 
@@ -47,25 +38,25 @@ class FrameEncoder(nn.Module):
 
 
 class FrameDecoder(nn.Module):
-    def __init__(self, config: FrameCnnConfig) -> None:
+    def __init__(self, config: dict) -> None:
         super().__init__()
 
-        assert len(config.mult) == len(config.down)
+        assert len(config["mult"]) == len(config["down"])
         decoder_layers = []
-        output_channels = config.num_channels
+        output_channels = config["num_channels"]
 
-        for m, d in zip(config.mult, config.down):
-            input_channels = m * config.num_channels
+        for m, d in zip(config["mult"], config["down"]):
+            input_channels = m * config["num_channels"]
             decoder_layers.append(ResidualBlock(input_channels, output_channels))
             output_channels = input_channels
             if d:
                 decoder_layers.append(Upsample(input_channels))
         decoder_layers.reverse()
-        decoder_layers.insert(0, nn.Conv2d(config.latent_dim, input_channels, kernel_size=3, stride=1, padding=1))
+        decoder_layers.insert(0, nn.Conv2d(config["latent_dim"], input_channels, kernel_size=3, stride=1, padding=1))
         decoder_layers.extend([
-            nn.GroupNorm(num_groups=32, num_channels=config.num_channels),
+            nn.GroupNorm(num_groups=32, num_channels=config["num_channels"]),
             nn.SiLU(inplace=True),
-            nn.Conv2d(config.num_channels, config.image_channels, kernel_size=3, stride=1, padding=1)
+            nn.Conv2d(config["num_channels"], config["image_channels"], kernel_size=3, stride=1, padding=1)
         ])
         self.decoder = nn.Sequential(*decoder_layers)
 
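Note: the hunks above replace the FrameCnnConfig dataclass with a plain dict keyed by the same field names. A minimal usage sketch, not part of the upload: the key names come from the diff, while the example values, the import path, and the assumption that the unchanged forward() takes (B, T, C, H, W) frame batches are illustrative only.

import torch

from src.models.convnet import FrameDecoder, FrameEncoder  # assumes the delta-iris/ root is on PYTHONPATH

# Dict config standing in for the removed FrameCnnConfig dataclass (illustrative values).
frame_cnn_config = {
    "image_channels": 3,      # RGB frames
    "latent_dim": 4,          # channels of the latent feature map
    "num_channels": 32,       # base channel count, scaled per stage by "mult"
    "mult": [1, 1, 1, 1],     # channel multiplier for each stage
    "down": [1, 1, 1, 0],     # 1 = downsample at that stage, 0 = keep resolution
}

encoder = FrameEncoder(frame_cnn_config)
decoder = FrameDecoder(frame_cnn_config)

frames = torch.randn(1, 2, 3, 64, 64)    # (B, T, C, H, W), assuming forward folds T into the batch
latents = encoder(frames)                # expected (1, 2, 4, 8, 8): 64 / 2**sum(down) = 8
recons = decoder(latents)                # expected (1, 2, 3, 64, 64)
print(latents.shape, recons.shape)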
delta-iris/src/models/tokenizer/__init__.py
CHANGED
@@ -1 +0,0 @@
-
delta-iris/src/models/transformer.py
CHANGED
@@ -11,58 +11,36 @@ import torch.nn as nn
 
 from .kv_caching import KeysValues, KVCache
 
-
-@dataclass
-class TransformerConfig:
-
-    tokens_per_block: int
-    max_blocks: int
-
-    num_layers: int
-    num_heads: int
-    embed_dim: int
-
-    attention: str
-
-    embed_pdrop: float
-    resid_pdrop: float
-    attn_pdrop: float
-
-    @property
-    def max_tokens(self):
-        return self.tokens_per_block * self.max_blocks
-
-
 class TransformerEncoder(nn.Module):
-    def __init__(self, config: TransformerConfig) -> None:
+    def __init__(self, config: dict) -> None:
         super().__init__()
         self.config = config
-
-        self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim)
-        self.emb_drop = nn.Dropout(config.embed_pdrop)
-        self.ln = nn.LayerNorm(config.embed_dim)
+        self.config["max_tokens"] = config["tokens_per_block"] * config["max_blocks"]
+        self.pos_emb = nn.Embedding(config["max_tokens"], config["embed_dim"])
+        self.emb_drop = nn.Dropout(config["embed_pdrop"])
+        self.ln = nn.LayerNorm(config["embed_dim"])
 
-        assert config.attention in ('causal', 'block_causal')
-        k, m = config.tokens_per_block, config.max_blocks
+        assert config["attention"] in ('causal', 'block_causal')
+        k, m = config["tokens_per_block"], config["max_blocks"]
         mask_sa = torch.tril(torch.ones(k * m, k * m))
-        if config.attention == 'block_causal':
+        if config["attention"] == 'block_causal':
             mask_sa = torch.max(mask_sa, torch.block_diag(*[torch.ones(k, k) for _ in range(m)]))
         mask_sa = mask_sa.bool()
 
-        self.blocks = nn.ModuleList([EncoderLayer(config, mask_sa) for _ in range(config.num_layers)])
+        self.blocks = nn.ModuleList([EncoderLayer(config, mask_sa) for _ in range(config["num_layers"])])
         self.keys_values = None
 
     @property
     def num_blocks_left_in_kv_cache(self) -> float:
         assert self.keys_values is not None
-        return (self.config.max_tokens - self.keys_values.size) / self.config.tokens_per_block
+        return (self.config["max_tokens"] - self.keys_values.size) / self.config["tokens_per_block"]
 
     def reset_kv_cache(self, n: int) -> None:
         device = self.ln.weight.device
-        self.keys_values = KeysValues(n, self.config.max_tokens, self.config.embed_dim, self.config.num_layers, device)
+        self.keys_values = KeysValues(n, self.config["max_tokens"], self.config["embed_dim"], self.config["num_layers"], device)
 
     def forward(self, x: torch.FloatTensor, use_kv_cache: bool = False) -> torch.FloatTensor:
-        assert x.ndim == 3 and x.size(2) == self.config.embed_dim  # (B, TK, E)
+        assert x.ndim == 3 and x.size(2) == self.config["embed_dim"]  # (B, TK, E)
 
         prev_steps = self.keys_values.size if use_kv_cache else 0
         inputs = x + self.pos_emb(prev_steps + torch.arange(x.size(1), device=x.device))
@@ -76,7 +54,7 @@ class TransformerEncoder(nn.Module):
 
 
 class EncoderLayer(nn.Module):
-    def __init__(self, config: TransformerConfig, mask_sa: torch.LongTensor) -> None:
+    def __init__(self, config: dict, mask_sa: torch.LongTensor) -> None:
         super().__init__()
         self.sa = SelfAttentionLayer(config, mask=mask_sa)
         self.mlp = MLPLayer(config)
@@ -86,14 +64,14 @@ class EncoderLayer(nn.Module):
 
 
 class MLPLayer(nn.Module):
-    def __init__(self, config: TransformerConfig) -> None:
+    def __init__(self, config: dict) -> None:
         super().__init__()
-        self.ln = nn.LayerNorm(config.embed_dim)
+        self.ln = nn.LayerNorm(config["embed_dim"])
         self.mlp = nn.Sequential(
-            nn.Linear(config.embed_dim, 4 * config.embed_dim),
+            nn.Linear(config["embed_dim"], 4 * config["embed_dim"]),
             nn.GELU(),
-            nn.Linear(4 * config.embed_dim, config.embed_dim),
-            nn.Dropout(config.resid_pdrop),
+            nn.Linear(4 * config["embed_dim"], config["embed_dim"]),
+            nn.Dropout(config["resid_pdrop"]),
         )
 
     def forward(self, inputs: torch.FloatTensor) -> torch.FloatTensor:
@@ -101,13 +79,13 @@ class MLPLayer(nn.Module):
 
 
 class SelfAttentionLayer(nn.Module):
-    def __init__(self, config: TransformerConfig, mask: torch.BoolTensor) -> None:
+    def __init__(self, config: dict, mask: torch.BoolTensor) -> None:
         super().__init__()
         self.register_buffer('mask', mask)
-        self.ln = nn.LayerNorm(config.embed_dim)
-        self.query = nn.Linear(config.embed_dim, config.embed_dim)
-        self.key = nn.Linear(config.embed_dim, config.embed_dim)
-        self.value = nn.Linear(config.embed_dim, config.embed_dim)
+        self.ln = nn.LayerNorm(config["embed_dim"])
+        self.query = nn.Linear(config["embed_dim"], config["embed_dim"])
+        self.key = nn.Linear(config["embed_dim"], config["embed_dim"])
+        self.value = nn.Linear(config["embed_dim"], config["embed_dim"])
         self.attention = Attention(config)
 
     def forward(self, inputs: torch.FloatTensor, kv_cache: Optional[KVCache] = None) -> torch.FloatTensor:
@@ -134,13 +112,13 @@ class SelfAttentionLayer(nn.Module):
 
 
 class Attention(nn.Module):
-    def __init__(self, config: TransformerConfig) -> None:
+    def __init__(self, config: dict) -> None:
         super().__init__()
-        assert config.embed_dim % config.num_heads == 0
-        self.num_heads = config.num_heads
-        self.attn_pdrop = config.attn_pdrop
-        self.resid_drop = nn.Dropout(config.resid_pdrop)
-        self.proj = nn.Linear(config.embed_dim, config.embed_dim)
+        assert config["embed_dim"] % config["num_heads"] == 0
+        self.num_heads = config["num_heads"]
+        self.attn_pdrop = config["attn_pdrop"]
+        self.resid_drop = nn.Dropout(config["resid_pdrop"])
+        self.proj = nn.Linear(config["embed_dim"], config["embed_dim"])
 
     def forward(self, q: torch.FloatTensor, k: torch.FloatTensor, v: torch.FloatTensor, mask: torch.BoolTensor) -> torch.FloatTensor:
         assert mask.size(0) == q.size(1) and mask.size(1) == k.size(1)
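Note: the removed TransformerConfig dataclass is likewise replaced by a dict, and TransformerEncoder now fills in "max_tokens" itself from tokens_per_block * max_blocks. A small sketch of the expected dict, with illustrative values and an assumed import path:

import torch

from src.models.transformer import TransformerEncoder  # assumes the delta-iris/ root is on PYTHONPATH

# Keys mirror the removed TransformerConfig fields; "max_tokens" is added by __init__.
transformer_config = {
    "tokens_per_block": 6,
    "max_blocks": 20,
    "num_layers": 4,
    "num_heads": 4,                  # must divide embed_dim (asserted in Attention)
    "embed_dim": 256,
    "attention": "block_causal",     # must be 'causal' or 'block_causal'
    "embed_pdrop": 0.1,
    "resid_pdrop": 0.1,
    "attn_pdrop": 0.1,
}

model = TransformerEncoder(transformer_config)
print(transformer_config["max_tokens"])   # 120, set as a side effect of __init__

x = torch.randn(2, 12, 256)               # (B, TK, E) with TK <= max_tokens
y = model(x)                               # full-sequence forward, no KV cache
print(y.shape)                             # expected (2, 12, 256)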
delta-iris/src/tokenizer.py
CHANGED
@@ -6,48 +6,32 @@ from einops import rearrange
 import torch
 import torch.nn as nn
 
-from .models.convnet import FrameCnnConfig, FrameDecoder, FrameEncoder
+from .models.convnet import FrameEncoder, FrameDecoder
 from .data import Batch
 from .models.tokenizer.quantizer import Quantizer, QuantizerOutput
 from .models.utils import init_weights, LossWithIntermediateLosses
 
-
-@dataclass
-class TokenizerConfig:
-    image_channels: int
-    image_size: int
-    num_actions: int
-    num_tokens: int
-    decoder_act_channels: int
-    codebook_size: int
-    codebook_dim: int
-    max_codebook_updates_with_revival: int
-    encoder_config: FrameCnnConfig
-    decoder_config: FrameCnnConfig
-    frame_cnn_config: FrameCnnConfig
-
-
 class Tokenizer(nn.Module):
-    def __init__(self, config: TokenizerConfig) -> None:
+    def __init__(self, config: dict) -> None:
         super().__init__()
         self.config = config
 
-        self.latent_res = config.image_size // 2 ** sum(config.encoder_config.down)
-        self.tokens_grid_res = int(math.sqrt(config.num_tokens))
+        self.latent_res = config["image_size"] // 2 ** sum(config["encoder_config"]["down"])
+        self.tokens_grid_res = int(math.sqrt(config["num_tokens"]))
         self.token_res = self.latent_res // self.tokens_grid_res
 
-        self.encoder_act_emb = nn.Embedding(config.num_actions, config.image_size ** 2)
-        self.decoder_act_emb = nn.Embedding(config.num_actions, config.decoder_act_channels * self.latent_res ** 2)
+        self.encoder_act_emb = nn.Embedding(config["num_actions"], config["image_size"] ** 2)
+        self.decoder_act_emb = nn.Embedding(config["num_actions"], config["decoder_act_channels"] * self.latent_res ** 2)
 
         self.quantizer = Quantizer(
-            config.codebook_size, config.codebook_dim,
-            input_dim=config.encoder_config.latent_dim * self.token_res ** 2,
-            max_codebook_updates_with_revival=config.max_codebook_updates_with_revival
+            config["codebook_size"], config["codebook_dim"],
+            input_dim=config["encoder_config"]["latent_dim"] * self.token_res ** 2,
+            max_codebook_updates_with_revival=config["max_codebook_updates_with_revival"]
        )
 
-        self.encoder = FrameEncoder(config.encoder_config)
-        self.decoder = FrameDecoder(config.decoder_config)
-        self.frame_cnn = FrameEncoder(config.frame_cnn_config)
+        self.encoder = FrameEncoder(config["encoder_config"])
+        self.decoder = FrameDecoder(config["decoder_config"])
+        self.frame_cnn = FrameEncoder(config["frame_cnn_config"])
 
         self.apply(init_weights)
 
@@ -89,7 +73,7 @@ class Tokenizer(nn.Module):
 
     def decode(self, x1: torch.FloatTensor, a: torch.LongTensor, q2: torch.FloatTensor, should_clamp: bool = False) -> torch.FloatTensor:
         x1_emb = self.frame_cnn(x1)
-        a_emb = rearrange(self.decoder_act_emb(a), 'b t (c h w) -> b t c h w', c=self.config.decoder_act_channels, h=x1_emb.size(3))
+        a_emb = rearrange(self.decoder_act_emb(a), 'b t (c h w) -> b t c h w', c=self.config["decoder_act_channels"], h=x1_emb.size(3))
 
         decoder_input = torch.cat((x1_emb, a_emb, q2), dim=2)
 
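Note: TokenizerConfig disappears as well; the Tokenizer now takes one nested dict whose sub-configs feed FrameEncoder / FrameDecoder directly. A constructor-only sketch with made-up values (key names from the diff, import path and the decoder latent_dim relationship are assumptions):

from src.tokenizer import Tokenizer   # assumes the delta-iris/ root is on PYTHONPATH

def frame_cnn(latent_dim: int) -> dict:
    # One nested config in the shape FrameEncoder / FrameDecoder now expect.
    return {"image_channels": 3, "latent_dim": latent_dim,
            "num_channels": 32, "mult": [1, 1, 1, 1], "down": [1, 1, 1, 0]}

tokenizer_config = {
    "image_channels": 3,
    "image_size": 64,
    "num_actions": 6,
    "num_tokens": 16,                    # 4x4 token grid over the 8x8 latent map
    "decoder_act_channels": 8,
    "codebook_size": 1024,
    "codebook_dim": 64,
    "max_codebook_updates_with_revival": 10000,
    "encoder_config": frame_cnn(4),
    # decode() concatenates frame_cnn output, action embedding and quantized latents along
    # the channel dim, so the decoder's latent_dim presumably needs to be 4 + 8 + 4 = 16 here.
    "decoder_config": frame_cnn(16),
    "frame_cnn_config": frame_cnn(4),
}

tokenizer = Tokenizer(tokenizer_config)
print(tokenizer.latent_res, tokenizer.tokens_grid_res, tokenizer.token_res)   # 8 4 2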
delta-iris/src/world_model.py
CHANGED
@@ -6,77 +6,54 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .models.convnet import FrameCnnConfig, FrameEncoder
+from .models.convnet import FrameEncoder
 from .data import Batch
 from .models.slicer import Head
 from .tokenizer import Tokenizer
-from .models.transformer import TransformerConfig, TransformerEncoder
+from .models.transformer import TransformerEncoder
 from .models.utils import init_weights, LossWithIntermediateLosses, symlog, two_hot
 
-
-@dataclass
-class WorldModelOutput:
-    output_sequence: torch.FloatTensor
-    logits_latents: torch.FloatTensor
-    logits_rewards: torch.FloatTensor
-    logits_ends: torch.FloatTensor
-
-
-@dataclass
-class WorldModelConfig:
-    latent_vocab_size: int
-    num_actions: int
-    image_channels: int
-    image_size: int
-    latents_weight: float
-    rewards_weight: float
-    ends_weight: float
-    two_hot_rews: bool
-    transformer_config: TransformerConfig
-    frame_cnn_config: FrameCnnConfig
-
-
 class WorldModel(nn.Module):
-    def __init__(self, config: WorldModelConfig) -> None:
+    def __init__(self, config: dict) -> None:
         super().__init__()
         self.config = config
-        self.transformer = TransformerEncoder(config.transformer_config)
+        self.transformer = TransformerEncoder(config["transformer_config"])
 
-        assert ((config.image_size // 2 ** sum(config.frame_cnn_config.down)) ** 2) * config.frame_cnn_config.latent_dim == config.transformer_config.embed_dim
-        self.frame_cnn = nn.Sequential(FrameEncoder(config.frame_cnn_config), Rearrange('b t c h w -> b t 1 (h w c)'), nn.LayerNorm(config.transformer_config.embed_dim))
+        assert ((config["image_size"] // 2 ** sum(config["frame_cnn_config"]["down"])) ** 2) * config["frame_cnn_config"]["latent_dim"] == config["transformer_config"]["embed_dim"]
+        self.frame_cnn = nn.Sequential(FrameEncoder(config["frame_cnn_config"]), Rearrange('b t c h w -> b t 1 (h w c)'), nn.LayerNorm(config["transformer_config"]["embed_dim"]))
 
-        self.act_emb = nn.Embedding(config.num_actions, config.transformer_config.embed_dim)
-        self.latents_emb = nn.Embedding(config.latent_vocab_size, config.transformer_config.embed_dim)
+        self.act_emb = nn.Embedding(config["num_actions"], config["transformer_config"]["embed_dim"])
+        self.latents_emb = nn.Embedding(config["latent_vocab_size"], config["transformer_config"]["embed_dim"])
 
-        act_pattern = torch.zeros(config.transformer_config.tokens_per_block)
+        act_pattern = torch.zeros(config["transformer_config"]["tokens_per_block"])
         act_pattern[1] = 1
-        act_and_latents_but_last_pattern = torch.zeros(config.transformer_config.tokens_per_block)
+        act_and_latents_but_last_pattern = torch.zeros(config["transformer_config"]["tokens_per_block"])
         act_and_latents_but_last_pattern[1:-1] = 1
 
         self.head_latents = Head(
-            max_blocks=config.transformer_config.max_blocks,
+            max_blocks=config["transformer_config"]["max_blocks"],
             block_mask=act_and_latents_but_last_pattern,
             head_module=nn.Sequential(
-                nn.Linear(config.transformer_config.embed_dim, config.transformer_config.embed_dim), nn.ReLU(),
-                nn.Linear(config.transformer_config.embed_dim, config.latent_vocab_size)
+                nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]), nn.ReLU(),
+                nn.Linear(config["transformer_config"]["embed_dim"], config["latent_vocab_size"])
             )
         )
 
         self.head_rewards = Head(
-            max_blocks=config.transformer_config.max_blocks,
+            max_blocks=config["transformer_config"]["max_blocks"],
             block_mask=act_pattern,
             head_module=nn.Sequential(
-                nn.Linear(config.transformer_config.embed_dim, config.transformer_config.embed_dim), nn.ReLU(),
-                nn.Linear(config.transformer_config.embed_dim, 255 if config.two_hot_rews else 3)
+                nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]), nn.ReLU(),
+                nn.Linear(config["transformer_config"]["embed_dim"], 255 if config["two_hot_rews"] else 3)
             )
         )
 
         self.head_ends = Head(
-            max_blocks=config.transformer_config.max_blocks,
+            max_blocks=config["transformer_config"]["max_blocks"],
             block_mask=act_pattern,
             head_module=nn.Sequential(
-                nn.Linear(config.transformer_config.embed_dim, config.transformer_config.embed_dim), nn.ReLU(),
-                nn.Linear(config.transformer_config.embed_dim, 2)
+                nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]), nn.ReLU(),
+                nn.Linear(config["transformer_config"]["embed_dim"], 2)
             )
         )
 
@@ -85,7 +62,7 @@ class WorldModel(nn.Module):
     def __repr__(self) -> str:
         return "world_model"
 
-    def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> WorldModelOutput:
+    def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:
         prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
         num_steps = sequence.size(1)
 
@@ -95,7 +72,12 @@
         logits_rewards = self.head_rewards(outputs, num_steps, prev_steps)
         logits_ends = self.head_ends(outputs, num_steps, prev_steps)
 
-        return WorldModelOutput(outputs, logits_latents, logits_rewards, logits_ends)
+        return {
+            "output_sequence": outputs,
+            "logits_latents": logits_latents,
+            "logits_rewards": logits_rewards,
+            "logits_ends": logits_ends
+        }
 
     def compute_loss(self, batch: Batch, tokenizer: Tokenizer, **kwargs) -> LossWithIntermediateLosses:
         assert torch.all(batch.ends.sum(dim=1) <= 1)
@@ -117,11 +99,11 @@
         labels_latents = latent_tokens[mask[:, :-1]].flatten()
         logits_latents = outputs.logits_latents[:, :-k][repeat(mask[:, :-1], 'b t -> b (t k)', k=k)]
         latent_acc = (logits_latents.max(dim=-1)[1] == labels_latents).float().mean()
-        labels_rewards = two_hot(symlog(batch.rewards)) if self.config.two_hot_rews else (batch.rewards.sign() + 1).long()
+        labels_rewards = two_hot(symlog(batch.rewards)) if self.config["two_hot_rews"] else (batch.rewards.sign() + 1).long()
 
-        loss_latents = F.cross_entropy(logits_latents, target=labels_latents) * self.config.latents_weight
-        loss_rewards = F.cross_entropy(outputs.logits_rewards[mask], target=labels_rewards[mask]) * self.config.rewards_weight
-        loss_ends = F.cross_entropy(outputs.logits_ends[mask], target=batch.ends[mask]) * self.config.ends_weight
+        loss_latents = F.cross_entropy(logits_latents, target=labels_latents) * self.config["latents_weight"]
+        loss_rewards = F.cross_entropy(outputs.logits_rewards[mask], target=labels_rewards[mask]) * self.config["rewards_weight"]
+        loss_ends = F.cross_entropy(outputs.logits_ends[mask], target=batch.ends[mask]) * self.config["ends_weight"]
 
         return LossWithIntermediateLosses(loss_latents=loss_latents, loss_rewards=loss_rewards, loss_ends=loss_ends), {'latent_accuracy': latent_acc}
 
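Note: the world model follows the same refactor: WorldModelConfig and WorldModelOutput are replaced by a plain dict config and a dict return value from forward(). A constructor sketch tying the pieces together; the sizes below are only an example chosen to satisfy the assert in __init__, since (64 // 2**3)**2 * 4 == 256, and the import path is an assumption.

from src.world_model import WorldModel   # assumes the delta-iris/ root is on PYTHONPATH

transformer_config = {
    "tokens_per_block": 6, "max_blocks": 20,
    "num_layers": 4, "num_heads": 4, "embed_dim": 256,
    "attention": "block_causal",
    "embed_pdrop": 0.1, "resid_pdrop": 0.1, "attn_pdrop": 0.1,
}

frame_cnn_config = {
    "image_channels": 3, "latent_dim": 4,
    "num_channels": 32, "mult": [1, 1, 1, 1], "down": [1, 1, 1, 0],
}

world_model_config = {
    "latent_vocab_size": 1024,
    "num_actions": 6,
    "image_channels": 3,
    "image_size": 64,            # (64 // 2**sum(down))**2 * latent_dim == embed_dim == 256
    "latents_weight": 1.0,
    "rewards_weight": 1.0,
    "ends_weight": 1.0,
    "two_hot_rews": False,       # reward head then has 3 outputs instead of 255
    "transformer_config": transformer_config,
    "frame_cnn_config": frame_cnn_config,
}

world_model = WorldModel(world_model_config)
# forward() now returns a dict with keys "output_sequence", "logits_latents",
# "logits_rewards" and "logits_ends" instead of a WorldModelOutput dataclass.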