ShaswatRobotics committed (verified)
Commit fb56df2 · 1 Parent(s): 5553f7f

Upload 35 files

delta-iris/src/models/__init__.py CHANGED
@@ -0,0 +1 @@
+ from .tokenizer import Tokenizer
delta-iris/src/models/convnet.py CHANGED
@@ -7,25 +7,16 @@ import torch.nn as nn
  import torch.nn.functional as F
 
 
- @dataclass
- class FrameCnnConfig:
-     image_channels: int
-     latent_dim: int
-     num_channels: int
-     mult: List[int]
-     down: List[int]
-
-
  class FrameEncoder(nn.Module):
-     def __init__(self, config: FrameCnnConfig) -> None:
+     def __init__(self, config: dict) -> None:
          super().__init__()
 
-         assert len(config.mult) == len(config.down)
-         encoder_layers = [nn.Conv2d(config.image_channels, config.num_channels, kernel_size=3, stride=1, padding=1)]
-         input_channels = config.num_channels
+         assert len(config["mult"]) == len(config["down"])
+         encoder_layers = [nn.Conv2d(config["image_channels"], config["num_channels"], kernel_size=3, stride=1, padding=1)]
+         input_channels = config["num_channels"]
 
-         for m, d in zip(config.mult, config.down):
-             output_channels = m * config.num_channels
+         for m, d in zip(config["mult"], config["down"]):
+             output_channels = m * config["num_channels"]
              encoder_layers.append(ResidualBlock(input_channels, output_channels))
              input_channels = output_channels
              if d:
@@ -33,7 +24,7 @@ class FrameEncoder(nn.Module):
          encoder_layers.extend([
              nn.GroupNorm(num_groups=32, num_channels=input_channels),
              nn.SiLU(inplace=True),
-             nn.Conv2d(input_channels, config.latent_dim, kernel_size=3, stride=1, padding=1)
+             nn.Conv2d(input_channels, config["latent_dim"], kernel_size=3, stride=1, padding=1)
          ])
          self.encoder = nn.Sequential(*encoder_layers)
 
@@ -47,25 +38,25 @@ class FrameEncoder(nn.Module):
 
 
  class FrameDecoder(nn.Module):
-     def __init__(self, config: FrameCnnConfig) -> None:
+     def __init__(self, config: dict) -> None:
          super().__init__()
 
-         assert len(config.mult) == len(config.down)
+         assert len(config["mult"]) == len(config["down"])
          decoder_layers = []
-         output_channels = config.num_channels
+         output_channels = config["num_channels"]
 
-         for m, d in zip(config.mult, config.down):
-             input_channels = m * config.num_channels
+         for m, d in zip(config["mult"], config["down"]):
+             input_channels = m * config["num_channels"]
              decoder_layers.append(ResidualBlock(input_channels, output_channels))
              output_channels = input_channels
              if d:
                  decoder_layers.append(Upsample(input_channels))
          decoder_layers.reverse()
-         decoder_layers.insert(0, nn.Conv2d(config.latent_dim, input_channels, kernel_size=3, stride=1, padding=1))
+         decoder_layers.insert(0, nn.Conv2d(config["latent_dim"], input_channels, kernel_size=3, stride=1, padding=1))
          decoder_layers.extend([
-             nn.GroupNorm(num_groups=32, num_channels=config.num_channels),
+             nn.GroupNorm(num_groups=32, num_channels=config["num_channels"]),
              nn.SiLU(inplace=True),
-             nn.Conv2d(config.num_channels, config.image_channels, kernel_size=3, stride=1, padding=1)
+             nn.Conv2d(config["num_channels"], config["image_channels"], kernel_size=3, stride=1, padding=1)
          ])
          self.decoder = nn.Sequential(*decoder_layers)
 
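Note: with FrameCnnConfig removed, FrameEncoder and FrameDecoder now take a plain dict. A minimal sketch of such a dict, using the keys visible in the diff; the import path and all numeric values are illustrative assumptions, not values taken from this repository:

# Hypothetical dict standing in for the removed FrameCnnConfig dataclass.
# Keys mirror the old dataclass fields; the numbers are placeholders.
from src.models.convnet import FrameEncoder, FrameDecoder  # assumes delta-iris/ on sys.path

frame_cnn_config = {
    "image_channels": 3,   # e.g. RGB frames
    "latent_dim": 4,       # channels produced by the final encoder conv
    "num_channels": 64,    # base channel count of the first conv
    "mult": [1, 2, 4],     # per-stage channel multipliers
    "down": [1, 1, 0],     # 1 = downsample after that stage
}

encoder = FrameEncoder(frame_cnn_config)
decoder = FrameDecoder(frame_cnn_config)

The only structural check left in the constructors is the assert that len(config["mult"]) == len(config["down"]); the field typing previously enforced by the dataclass is now the caller's responsibility.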
delta-iris/src/models/tokenizer/__init__.py CHANGED
@@ -1 +0,0 @@
-
delta-iris/src/models/transformer.py CHANGED
@@ -11,58 +11,36 @@ import torch.nn as nn
 
  from .kv_caching import KeysValues, KVCache
 
-
- @dataclass
- class TransformerConfig:
-
-     tokens_per_block: int
-     max_blocks: int
-
-     num_layers: int
-     num_heads: int
-     embed_dim: int
-
-     attention: str
-
-     embed_pdrop: float
-     resid_pdrop: float
-     attn_pdrop: float
-
-     @property
-     def max_tokens(self):
-         return self.tokens_per_block * self.max_blocks
-
-
  class TransformerEncoder(nn.Module):
-     def __init__(self, config: TransformerConfig) -> None:
+     def __init__(self, config: dict) -> None:
          super().__init__()
          self.config = config
-
-         self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim)
-         self.emb_drop = nn.Dropout(config.embed_pdrop)
-         self.ln = nn.LayerNorm(config.embed_dim)
-
-         assert config.attention in ('causal', 'block_causal')
-         k, m = config.tokens_per_block, config.max_blocks
+         self.config["max_tokens"] = config["tokens_per_block"] * config["max_blocks"]
+         self.pos_emb = nn.Embedding(config["max_tokens"], config["embed_dim"])
+         self.emb_drop = nn.Dropout(config["embed_pdrop"])
+         self.ln = nn.LayerNorm(config["embed_dim"])
+
+         assert config["attention"] in ('causal', 'block_causal')
+         k, m = config["tokens_per_block"], config["max_blocks"]
          mask_sa = torch.tril(torch.ones(k * m, k * m))
-         if config.attention == 'block_causal':
+         if config["attention"] == 'block_causal':
              mask_sa = torch.max(mask_sa, torch.block_diag(*[torch.ones(k, k) for _ in range(m)]))
          mask_sa = mask_sa.bool()
 
-         self.blocks = nn.ModuleList([EncoderLayer(config, mask_sa) for _ in range(config.num_layers)])
+         self.blocks = nn.ModuleList([EncoderLayer(config, mask_sa) for _ in range(config["num_layers"])])
          self.keys_values = None
 
      @property
      def num_blocks_left_in_kv_cache(self) -> float:
          assert self.keys_values is not None
-         return (self.config.max_tokens - self.keys_values.size) / self.config.tokens_per_block
+         return (self.config["max_tokens"] - self.keys_values.size) / self.config["tokens_per_block"]
 
      def reset_kv_cache(self, n: int) -> None:
          device = self.ln.weight.device
-         self.keys_values = KeysValues(n, self.config.max_tokens, self.config.embed_dim, self.config.num_layers, device)
+         self.keys_values = KeysValues(n, self.config["max_tokens"], self.config["embed_dim"], self.config["num_layers"], device)
 
      def forward(self, x: torch.FloatTensor, use_kv_cache: bool = False) -> torch.FloatTensor:
-         assert x.ndim == 3 and x.size(2) == self.config.embed_dim  # (B, TK, E)
+         assert x.ndim == 3 and x.size(2) == self.config["embed_dim"]  # (B, TK, E)
 
          prev_steps = self.keys_values.size if use_kv_cache else 0
          inputs = x + self.pos_emb(prev_steps + torch.arange(x.size(1), device=x.device))
@@ -76,7 +54,7 @@ class TransformerEncoder(nn.Module):
 
 
  class EncoderLayer(nn.Module):
-     def __init__(self, config: TransformerConfig, mask_sa: torch.LongTensor) -> None:
+     def __init__(self, config: dict, mask_sa: torch.LongTensor) -> None:
          super().__init__()
          self.sa = SelfAttentionLayer(config, mask=mask_sa)
          self.mlp = MLPLayer(config)
@@ -86,14 +64,14 @@ class EncoderLayer(nn.Module):
 
 
  class MLPLayer(nn.Module):
-     def __init__(self, config: TransformerConfig) -> None:
+     def __init__(self, config: dict) -> None:
          super().__init__()
-         self.ln = nn.LayerNorm(config.embed_dim)
+         self.ln = nn.LayerNorm(config["embed_dim"])
          self.mlp = nn.Sequential(
-             nn.Linear(config.embed_dim, 4 * config.embed_dim),
+             nn.Linear(config["embed_dim"], 4 * config["embed_dim"]),
              nn.GELU(),
-             nn.Linear(4 * config.embed_dim, config.embed_dim),
-             nn.Dropout(config.resid_pdrop),
+             nn.Linear(4 * config["embed_dim"], config["embed_dim"]),
+             nn.Dropout(config["resid_pdrop"]),
          )
 
      def forward(self, inputs: torch.FloatTensor) -> torch.FloatTensor:
@@ -101,13 +79,13 @@ class MLPLayer(nn.Module):
 
 
  class SelfAttentionLayer(nn.Module):
-     def __init__(self, config: TransformerConfig, mask: torch.BoolTensor) -> None:
+     def __init__(self, config: dict, mask: torch.BoolTensor) -> None:
          super().__init__()
          self.register_buffer('mask', mask)
-         self.ln = nn.LayerNorm(config.embed_dim)
-         self.query = nn.Linear(config.embed_dim, config.embed_dim)
-         self.key = nn.Linear(config.embed_dim, config.embed_dim)
-         self.value = nn.Linear(config.embed_dim, config.embed_dim)
+         self.ln = nn.LayerNorm(config["embed_dim"])
+         self.query = nn.Linear(config["embed_dim"], config["embed_dim"])
+         self.key = nn.Linear(config["embed_dim"], config["embed_dim"])
+         self.value = nn.Linear(config["embed_dim"], config["embed_dim"])
          self.attention = Attention(config)
 
      def forward(self, inputs: torch.FloatTensor, kv_cache: Optional[KVCache] = None) -> torch.FloatTensor:
@@ -134,13 +112,13 @@ class SelfAttentionLayer(nn.Module):
 
 
  class Attention(nn.Module):
-     def __init__(self, config: TransformerConfig) -> None:
+     def __init__(self, config: dict) -> None:
          super().__init__()
-         assert config.embed_dim % config.num_heads == 0
-         self.num_heads = config.num_heads
-         self.attn_pdrop = config.attn_pdrop
-         self.resid_drop = nn.Dropout(config.resid_pdrop)
-         self.proj = nn.Linear(config.embed_dim, config.embed_dim)
+         assert config["embed_dim"] % config["num_heads"] == 0
+         self.num_heads = config["num_heads"]
+         self.attn_pdrop = config["attn_pdrop"]
+         self.resid_drop = nn.Dropout(config["resid_pdrop"])
+         self.proj = nn.Linear(config["embed_dim"], config["embed_dim"])
 
      def forward(self, q: torch.FloatTensor, k: torch.FloatTensor, v: torch.FloatTensor, mask: torch.BoolTensor) -> torch.FloatTensor:
          assert mask.size(0) == q.size(1) and mask.size(1) == k.size(1)
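Note: TransformerConfig and its max_tokens property are gone; TransformerEncoder now expects a plain dict and writes the derived "max_tokens" entry back into it during __init__. A sketch of the expected keys, with illustrative values only (the import path is likewise an assumption):

# Hypothetical transformer config dict; "max_tokens" is not supplied by the
# caller, it is computed in __init__ as tokens_per_block * max_blocks.
from src.models.transformer import TransformerEncoder  # assumes delta-iris/ on sys.path

transformer_config = {
    "tokens_per_block": 6,
    "max_blocks": 20,
    "num_layers": 3,
    "num_heads": 4,
    "embed_dim": 256,
    "attention": "block_causal",  # must be 'causal' or 'block_causal'
    "embed_pdrop": 0.1,
    "resid_pdrop": 0.1,
    "attn_pdrop": 0.1,
}

transformer = TransformerEncoder(transformer_config)
assert transformer_config["max_tokens"] == 6 * 20  # the dict is mutated in place

Because the dict is mutated in place, any other component holding a reference to the same dict will also see "max_tokens"; with the dataclass this was a read-only property, so the difference is worth keeping in mind.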
delta-iris/src/tokenizer.py CHANGED
@@ -6,48 +6,32 @@ from einops import rearrange
  import torch
  import torch.nn as nn
 
- from .models.convnet import FrameCnnConfig, FrameEncoder, FrameDecoder
+ from .models.convnet import FrameEncoder, FrameDecoder
  from .data import Batch
  from .models.tokenizer.quantizer import Quantizer, QuantizerOutput
  from .models.utils import init_weights, LossWithIntermediateLosses
 
-
- @dataclass
- class TokenizerConfig:
-     image_channels: int
-     image_size: int
-     num_actions: int
-     num_tokens: int
-     decoder_act_channels: int
-     codebook_size: int
-     codebook_dim: int
-     max_codebook_updates_with_revival: int
-     encoder_config: FrameCnnConfig
-     decoder_config: FrameCnnConfig
-     frame_cnn_config: FrameCnnConfig
-
-
  class Tokenizer(nn.Module):
-     def __init__(self, config: TokenizerConfig) -> None:
+     def __init__(self, config: dict) -> None:
          super().__init__()
          self.config = config
 
-         self.latent_res = config.image_size // 2 ** sum(config.encoder_config.down)
-         self.tokens_grid_res = int(math.sqrt(config.num_tokens))
+         self.latent_res = config["image_size"] // 2 ** sum(config["encoder_config"]["down"])
+         self.tokens_grid_res = int(math.sqrt(config["num_tokens"]))
          self.token_res = self.latent_res // self.tokens_grid_res
 
-         self.encoder_act_emb = nn.Embedding(config.num_actions, config.image_size ** 2)
-         self.decoder_act_emb = nn.Embedding(config.num_actions, config.decoder_act_channels * self.latent_res ** 2)
+         self.encoder_act_emb = nn.Embedding(config["num_actions"], config["image_size"] ** 2)
+         self.decoder_act_emb = nn.Embedding(config["num_actions"], config["decoder_act_channels"] * self.latent_res ** 2)
 
          self.quantizer = Quantizer(
-             config.codebook_size, config.codebook_dim,
-             input_dim=config.encoder_config.latent_dim * self.token_res ** 2,
-             max_codebook_updates_with_revival=config.max_codebook_updates_with_revival
+             config["codebook_size"], config["codebook_dim"],
+             input_dim=config["encoder_config"]["latent_dim"] * self.token_res ** 2,
+             max_codebook_updates_with_revival=config["max_codebook_updates_with_revival"]
          )
 
-         self.encoder = FrameEncoder(config.encoder_config)
-         self.decoder = FrameDecoder(config.decoder_config)
-         self.frame_cnn = FrameEncoder(config.frame_cnn_config)
+         self.encoder = FrameEncoder(config["encoder_config"])
+         self.decoder = FrameDecoder(config["decoder_config"])
+         self.frame_cnn = FrameEncoder(config["frame_cnn_config"])
 
          self.apply(init_weights)
 
@@ -89,7 +73,7 @@ class Tokenizer(nn.Module):
 
      def decode(self, x1: torch.FloatTensor, a: torch.LongTensor, q2: torch.FloatTensor, should_clamp: bool = False) -> torch.FloatTensor:
          x1_emb = self.frame_cnn(x1)
-         a_emb = rearrange(self.decoder_act_emb(a), 'b t (c h w) -> b t c h w', c=self.config.decoder_act_channels, h=x1_emb.size(3))
+         a_emb = rearrange(self.decoder_act_emb(a), 'b t (c h w) -> b t c h w', c=self.config["decoder_act_channels"], h=x1_emb.size(3))
 
          decoder_input = torch.cat((x1_emb, a_emb, q2), dim=2)
 
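Note: TokenizerConfig is removed, so Tokenizer now takes a nested dict whose "encoder_config", "decoder_config" and "frame_cnn_config" entries are themselves convnet-style dicts as sketched above. A rough example under those assumptions (import path and every value are placeholders):

# Hypothetical nested tokenizer config; sub-dicts follow the convnet.py keys.
from src.tokenizer import Tokenizer  # assumes delta-iris/ on sys.path

def cnn_cfg(latent_dim: int) -> dict:
    # Illustrative convnet dict, see the convnet.py note above.
    return {"image_channels": 3, "latent_dim": latent_dim, "num_channels": 64,
            "mult": [1, 2, 4], "down": [1, 1, 0]}

tokenizer_config = {
    "image_channels": 3,
    "image_size": 64,
    "num_actions": 6,
    "num_tokens": 16,        # tokens_grid_res = int(sqrt(num_tokens)) = 4
    "decoder_act_channels": 8,
    "codebook_size": 1024,
    "codebook_dim": 64,
    "max_codebook_updates_with_revival": 250,
    "encoder_config": cnn_cfg(latent_dim=16),
    "decoder_config": cnn_cfg(latent_dim=16),
    "frame_cnn_config": cnn_cfg(latent_dim=3),
}

tokenizer = Tokenizer(tokenizer_config)

Note that latent_res and token_res are still derived inside __init__ from "image_size", the encoder's "down" list and "num_tokens", so those entries have to stay mutually consistent.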
delta-iris/src/world_model.py CHANGED
@@ -6,77 +6,54 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F
 
- from .models.convnet import FrameCnnConfig, FrameEncoder
+ from .models.convnet import FrameEncoder
  from .data import Batch
  from .models.slicer import Head
  from .tokenizer import Tokenizer
- from .models.transformer import TransformerEncoder, TransformerConfig
+ from .models.transformer import TransformerEncoder
  from .models.utils import init_weights, LossWithIntermediateLosses, symlog, two_hot
 
-
- @dataclass
- class WorldModelOutput:
-     output_sequence: torch.FloatTensor
-     logits_latents: torch.FloatTensor
-     logits_rewards: torch.FloatTensor
-     logits_ends: torch.FloatTensor
-
-
- @dataclass
- class WorldModelConfig:
-     latent_vocab_size: int
-     num_actions: int
-     image_channels: int
-     image_size: int
-     latents_weight: float
-     rewards_weight: float
-     ends_weight: float
-     two_hot_rews: bool
-     transformer_config: TransformerConfig
-     frame_cnn_config: FrameCnnConfig
-
-
  class WorldModel(nn.Module):
-     def __init__(self, config: WorldModelConfig) -> None:
+     def __init__(self, config: dict) -> None:
          super().__init__()
          self.config = config
-         self.transformer = TransformerEncoder(config.transformer_config)
+         self.transformer = TransformerEncoder(config["transformer_config"])
 
-         assert ((config.image_size // 2 ** sum(config.frame_cnn_config.down)) ** 2) * config.frame_cnn_config.latent_dim == config.transformer_config.embed_dim
-         self.frame_cnn = nn.Sequential(FrameEncoder(config.frame_cnn_config), Rearrange('b t c h w -> b t 1 (h w c)'), nn.LayerNorm(config.transformer_config.embed_dim))
+         assert ((config["image_size"] // 2 ** sum(config["frame_cnn_config"]["down"])) ** 2) * config["frame_cnn_config"]["latent_dim"] == config["transformer_config"]["embed_dim"]
+         self.frame_cnn = nn.Sequential(FrameEncoder(config["frame_cnn_config"]), Rearrange('b t c h w -> b t 1 (h w c)'), nn.LayerNorm(config["transformer_config"]["embed_dim"]))
 
-         self.act_emb = nn.Embedding(config.num_actions, config.transformer_config.embed_dim)
-         self.latents_emb = nn.Embedding(config.latent_vocab_size, config.transformer_config.embed_dim)
+         self.act_emb = nn.Embedding(config["num_actions"], config["transformer_config"]["embed_dim"])
+         self.latents_emb = nn.Embedding(config["latent_vocab_size"], config["transformer_config"]["embed_dim"])
 
-         act_pattern = torch.zeros(config.transformer_config.tokens_per_block)
+         act_pattern = torch.zeros(config["transformer_config"]["tokens_per_block"])
          act_pattern[1] = 1
-         act_and_latents_but_last_pattern = torch.zeros(config.transformer_config.tokens_per_block)
+         act_and_latents_but_last_pattern = torch.zeros(config["transformer_config"]["tokens_per_block"])
          act_and_latents_but_last_pattern[1:-1] = 1
 
          self.head_latents = Head(
-             max_blocks=config.transformer_config.max_blocks,
+             max_blocks=config["transformer_config"]["max_blocks"],
              block_mask=act_and_latents_but_last_pattern,
              head_module=nn.Sequential(
-                 nn.Linear(config.transformer_config.embed_dim, config.transformer_config.embed_dim), nn.ReLU(),
-                 nn.Linear(config.transformer_config.embed_dim, config.latent_vocab_size)
+                 nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]), nn.ReLU(),
+                 nn.Linear(config["transformer_config"]["embed_dim"], config["latent_vocab_size"])
              )
          )
 
          self.head_rewards = Head(
-             max_blocks=config.transformer_config.max_blocks,
+             max_blocks=config["transformer_config"]["max_blocks"],
              block_mask=act_pattern,
              head_module=nn.Sequential(
-                 nn.Linear(config.transformer_config.embed_dim, config.transformer_config.embed_dim), nn.ReLU(),
-                 nn.Linear(config.transformer_config.embed_dim, 255 if config.two_hot_rews else 3)
+                 nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]), nn.ReLU(),
+                 nn.Linear(config["transformer_config"]["embed_dim"], 255 if config["two_hot_rews"] else 3)
              )
          )
 
          self.head_ends = Head(
-             max_blocks=config.transformer_config.max_blocks,
+             max_blocks=config["transformer_config"]["max_blocks"],
              block_mask=act_pattern,
              head_module=nn.Sequential(
-                 nn.Linear(config.transformer_config.embed_dim, config.transformer_config.embed_dim), nn.ReLU(),
-                 nn.Linear(config.transformer_config.embed_dim, 2)
+                 nn.Linear(config["transformer_config"]["embed_dim"], config["transformer_config"]["embed_dim"]), nn.ReLU(),
+                 nn.Linear(config["transformer_config"]["embed_dim"], 2)
              )
          )
 
@@ -85,7 +62,7 @@ class WorldModel(nn.Module):
 
      def __repr__(self) -> str:
          return "world_model"
 
-     def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> WorldModelOutput:
+     def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> dict:
          prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
          num_steps = sequence.size(1)
@@ -95,7 +72,12 @@ class WorldModel(nn.Module):
          logits_rewards = self.head_rewards(outputs, num_steps, prev_steps)
          logits_ends = self.head_ends(outputs, num_steps, prev_steps)
 
-         return WorldModelOutput(outputs, logits_latents, logits_rewards, logits_ends)
+         return {
+             "output_sequence": outputs,
+             "logits_latents": logits_latents,
+             "logits_rewards": logits_rewards,
+             "logits_ends": logits_ends
+         }
 
      def compute_loss(self, batch: Batch, tokenizer: Tokenizer, **kwargs) -> LossWithIntermediateLosses:
          assert torch.all(batch.ends.sum(dim=1) <= 1)
@@ -117,11 +99,11 @@ class WorldModel(nn.Module):
          labels_latents = latent_tokens[mask[:, :-1]].flatten()
          logits_latents = outputs.logits_latents[:, :-k][repeat(mask[:, :-1], 'b t -> b (t k)', k=k)]
          latent_acc = (logits_latents.max(dim=-1)[1] == labels_latents).float().mean()
-         labels_rewards = two_hot(symlog(batch.rewards)) if self.config.two_hot_rews else (batch.rewards.sign() + 1).long()
+         labels_rewards = two_hot(symlog(batch.rewards)) if self.config["two_hot_rews"] else (batch.rewards.sign() + 1).long()
 
-         loss_latents = F.cross_entropy(logits_latents, target=labels_latents) * self.config.latents_weight
-         loss_rewards = F.cross_entropy(outputs.logits_rewards[mask], target=labels_rewards[mask]) * self.config.rewards_weight
-         loss_ends = F.cross_entropy(outputs.logits_ends[mask], target=batch.ends[mask]) * self.config.ends_weight
+         loss_latents = F.cross_entropy(logits_latents, target=labels_latents) * self.config["latents_weight"]
+         loss_rewards = F.cross_entropy(outputs.logits_rewards[mask], target=labels_rewards[mask]) * self.config["rewards_weight"]
+         loss_ends = F.cross_entropy(outputs.logits_ends[mask], target=batch.ends[mask]) * self.config["ends_weight"]
 
          return LossWithIntermediateLosses(loss_latents=loss_latents, loss_rewards=loss_rewards, loss_ends=loss_ends), {'latent_accuracy': latent_acc}
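Note: WorldModelConfig and WorldModelOutput are both dropped; WorldModel now takes a nested dict and forward returns a plain dict. A sketch of the new call pattern follows, where the import path and all values are illustrative assumptions. One thing to double-check: compute_loss still reads outputs.logits_latents, outputs.logits_rewards and outputs.logits_ends with attribute access, which would not work on a dict, so if those outputs come from forward they would need the same ["..."] access.

# Hypothetical usage after this commit (values chosen so that the constructor's
# assert holds: (image_size // 2**sum(down))**2 * latent_dim == embed_dim).
from src.world_model import WorldModel  # assumes delta-iris/ on sys.path

frame_cnn_cfg = {"image_channels": 3, "latent_dim": 4, "num_channels": 32,
                 "mult": [1, 1, 1], "down": [1, 1, 1]}   # 64x64 -> 8x8, (8*8)*4 = 256
transformer_cfg = {"tokens_per_block": 6, "max_blocks": 20, "num_layers": 3,
                   "num_heads": 4, "embed_dim": 256, "attention": "block_causal",
                   "embed_pdrop": 0.0, "resid_pdrop": 0.0, "attn_pdrop": 0.0}

world_model = WorldModel({
    "latent_vocab_size": 1024, "num_actions": 6, "image_channels": 3,
    "image_size": 64, "latents_weight": 0.1, "rewards_weight": 1.0,
    "ends_weight": 0.1, "two_hot_rews": False,
    "transformer_config": transformer_cfg, "frame_cnn_config": frame_cnn_cfg,
})

# `sequence` is assumed to be a (B, T*K, embed_dim) float tensor of
# interleaved frame/action/latent embeddings, as before the commit.
outputs = world_model(sequence)                  # now a plain dict
logits_rewards = outputs["logits_rewards"]       # was outputs.logits_rewards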