schrum2 committed
Commit 0ef1dd5 · 1 Parent(s): 179f3e8

Attempting new file structure that requires fewer drastic changes

model_index.json CHANGED
@@ -1,20 +1,11 @@
 {
   "_class_name": "TextConditionalDDPMPipeline",
   "_diffusers_version": "0.32.2",
-  "scheduler": [
-    "diffusers",
-    "DDPMScheduler"
-  ],
-  "text_encoder": [
-    "models.text_model",
-    "TransformerModel"
-  ],
-  "tokenizer": [
-    "tokenizer",
-    "Tokenizer"
-  ],
-  "unet": [
-    "diffusers",
-    "UNet2DConditionModel"
-  ]
-}
+  "custom_pipeline": "models/text_diffusion_pipeline.py",
+  "components": {
+    "unet": { "type": "UNet2DConditionModel", "subfolder": "unet" },
+    "text_encoder": { "type": "models.text_model.TransformerModel", "subfolder": "text_encoder" },
+    "tokenizer": { "type": "Tokenizer", "file": "tokenizer.py" },
+    "scheduler": { "type": "DDPMScheduler", "subfolder": "scheduler" }
+  }
+}
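
The new model_index.json records every component under a single "components" map and points "custom_pipeline" at the pipeline module added below, so loading goes through the custom from_pretrained defined there rather than standard diffusers component resolution. A minimal loading sketch, assuming a local checkout of this repository ("path/to/model" is a hypothetical path, not part of the commit):

    from models.text_diffusion_pipeline import TextConditionalDDPMPipeline

    pipe = TextConditionalDDPMPipeline.from_pretrained("path/to/model")
    pipe = pipe.to("cuda")  # the overridden to() also moves the text encoder
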
models/text_diffusion_pipeline.py ADDED
@@ -0,0 +1,442 @@
+import os
+import json
+from typing import NamedTuple, Optional
+
+import torch
+import torch.nn.functional as F
+from diffusers import DDPMPipeline, UNet2DConditionModel, DDPMScheduler
+from transformers import AutoTokenizer, AutoModel
+
+# Running the main at the end of this requires messing with this import
+from models.text_model import TransformerModel
+import models.text_model as text_model
+import models.sentence_transformers_helper as st_helper
+import util.common_settings as common_settings
+
+# Only needed when block_embeddings is set (see __call__); the module is not part of this commit
+try:
+    from models.general_training_helper import get_scene_from_embeddings
+except ImportError:
+    get_scene_from_embeddings = None
+
+
+class PipelineOutput(NamedTuple):
+    images: torch.Tensor
+
+
+# Custom pipeline for text-conditional generation
+class TextConditionalDDPMPipeline(DDPMPipeline):
+    def __init__(self, unet, scheduler, text_encoder=None, tokenizer=None, supports_pretrained_split=False, block_embeddings=None):
+        super().__init__(unet=unet, scheduler=scheduler)
+        self.text_encoder = text_encoder
+        self.tokenizer = tokenizer
+        self.supports_negative_prompt = hasattr(unet, 'negative_prompt_support') and unet.negative_prompt_support
+        self.supports_pretrained_split = supports_pretrained_split
+        self.block_embeddings = block_embeddings
+
+        if self.tokenizer is None and self.text_encoder is not None:
+            # Use the tokenizer from the text encoder if one was not provided
+            self.tokenizer = self.text_encoder.tokenizer
+
+        # Register the text_encoder so that .to(), .cpu(), .cuda(), etc. work correctly
+        self.register_modules(
+            unet=unet,
+            scheduler=scheduler,
+            text_encoder=self.text_encoder,
+            tokenizer=self.tokenizer,
+        )
+
+    # Override the to() method to ensure the text_encoder is moved to the correct device
+    def to(self, device=None, dtype=None):
+        # Call the parent's to() method first
+        pipeline = super().to(device, dtype)
+
+        # Additionally move the text_encoder to the device
+        if self.text_encoder is not None:
+            self.text_encoder.to(device)
+
+        return pipeline
+
+    def save_pretrained(self, save_directory):
+        os.makedirs(save_directory, exist_ok=True)
+        super().save_pretrained(save_directory)  # saves UNet and scheduler
+
+        # Save the block_embeddings tensor if it exists
+        if self.block_embeddings is not None:
+            torch.save(self.block_embeddings, os.path.join(save_directory, "block_embeddings.pt"))
+
+        # Save the supports_negative_prompt and supports_pretrained_split flags
+        with open(os.path.join(save_directory, "pipeline_config.json"), "w") as f:
+            json.dump({
+                "supports_negative_prompt": self.supports_negative_prompt,
+                "supports_pretrained_split": self.supports_pretrained_split,
+                "text_encoder_type": type(self.text_encoder).__name__
+            }, f)
+
+        # Text encoder/tokenizer saving differs depending on whether we're using a larger pretrained model
+        if isinstance(self.text_encoder, TransformerModel):
+            # Save the custom text encoder
+            self.text_encoder.save_pretrained(os.path.join(save_directory, "text_encoder"))
+        else:
+            # Save the pretrained text encoder/tokenizer by name, so we can load from Hugging Face
+            # instead of saving a giant local model
+            text_encoder_info = {
+                "text_encoder_name": self.text_encoder.config.name_or_path,
+                "tokenizer_name": self.tokenizer.name_or_path,
+            }
+
+            text_encoder_directory = os.path.join(save_directory, "text_encoder")
+            os.makedirs(text_encoder_directory, exist_ok=True)
+
+            with open(os.path.join(text_encoder_directory, "loading_info.json"), "w") as f:
+                json.dump(text_encoder_info, f)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_path, **kwargs):
+        # Load components manually
+        unet_path = os.path.join(pretrained_model_path, "unet")
+        unet = UNet2DConditionModel.from_pretrained(unet_path)
+
+        scheduler_path = os.path.join(pretrained_model_path, "scheduler")
+        # DDIMScheduler is reportedly faster for inference, though not necessarily better
+        scheduler = DDPMScheduler.from_pretrained(scheduler_path)
+
+        tokenizer = None
+        text_encoder_path = os.path.join(pretrained_model_path, "text_encoder")
+
+        if os.path.exists(text_encoder_path):
+            # Check for the new saving system, where we save a simple config file
+            if os.path.exists(os.path.join(text_encoder_path, "loading_info.json")):
+                with open(os.path.join(text_encoder_path, "loading_info.json"), "r") as f:
+                    encoder_config = json.load(f)
+
+                text_encoder = AutoModel.from_pretrained(encoder_config['text_encoder_name'], trust_remote_code=True)
+                tokenizer = AutoTokenizer.from_pretrained(encoder_config['tokenizer_name'])
+
+            # Legacy loading system: loads models directly if the whole thing is saved in the directory
+            else:
+                try:
+                    text_encoder = AutoModel.from_pretrained(text_encoder_path, local_files_only=True, trust_remote_code=True)
+                    tokenizer = AutoTokenizer.from_pretrained(text_encoder_path, local_files_only=True)
+                except (ValueError, KeyError):
+                    text_encoder = TransformerModel.from_pretrained(text_encoder_path)
+                    tokenizer = text_encoder.tokenizer
+        else:
+            text_encoder = None
+
+        # Instantiate the pipeline
+        pipeline = cls(
+            unet=unet,
+            scheduler=scheduler,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            **kwargs,
+        )
+
+        # Load block embeddings if present
+        block_embeds_path = os.path.join(pretrained_model_path, "block_embeddings.pt")
+        if os.path.exists(block_embeds_path):
+            pipeline.block_embeddings = torch.load(block_embeds_path, map_location="cpu")
+        else:
+            pipeline.block_embeddings = None
+
+        # Load the supports_negative_prompt flag if present
+        config_path = os.path.join(pretrained_model_path, "pipeline_config.json")
+        if os.path.exists(config_path):
+            with open(config_path, "r") as f:
+                config = json.load(f)
+            pipeline.supports_negative_prompt = config.get("supports_negative_prompt", False)
+            pipeline.supports_pretrained_split = config.get("supports_pretrained_split", False)
+        return pipeline
+
+    # --- Handle batching for captions ---
+    def _prepare_text_batch(self, text: Optional[str | list[str]], batch_size: int, name: str) -> Optional[list[str]]:
+        if text is None:
+            return None
+        if isinstance(text, str):
+            return [text] * batch_size
+        if isinstance(text, list):
+            if len(text) == 1:
+                return text * batch_size
+            if len(text) != batch_size:
+                raise ValueError(f"{name} list length {len(text)} does not match batch_size {batch_size}")
+            return text
+        raise ValueError(f"{name} must be a string or list of strings")
+
+    def _prepare_initial_sample(self,
+                                raw_latent_sample: Optional[torch.Tensor],
+                                input_scene: Optional[torch.Tensor],
+                                batch_size: int, height: int, width: int,
+                                generator: Optional[torch.Generator]) -> torch.Tensor:
+        """Prepare the initial sample for diffusion."""
+
+        sample_shape = (batch_size, self.unet.config.in_channels, height, width)
+
+        if raw_latent_sample is not None:
+            if input_scene is not None:
+                raise ValueError("Cannot provide both raw_latent_sample and input_scene")
+            sample = raw_latent_sample.to(self.device)
+            if sample.shape[1] != sample_shape[1]:
+                raise ValueError(f"Wrong number of channels in raw_latent_sample: expected {self.unet.config.in_channels} but got {sample.shape[1]}")
+            if sample.shape[0] == 1 and batch_size > 1:
+                sample = sample.repeat(batch_size, 1, 1, 1)
+            elif sample.shape[0] != batch_size:
+                raise ValueError(f"raw_latent_sample batch size {sample.shape[0]} does not match batch_size {batch_size}")
+        elif input_scene is not None:
+            # input_scene can be (H, W) or (batch_size, H, W)
+            scene_tensor = torch.as_tensor(input_scene, dtype=torch.long, device=self.device)
+            if scene_tensor.dim() == 2:
+                # (H, W) -> repeat for the batch
+                scene_tensor = scene_tensor.unsqueeze(0).repeat(batch_size, 1, 1)
+            elif scene_tensor.shape[0] == 1 and batch_size > 1:
+                scene_tensor = scene_tensor.repeat(batch_size, 1, 1)
+            elif scene_tensor.shape[0] != batch_size:
+                raise ValueError(f"input_scene batch size {scene_tensor.shape[0]} does not match batch_size {batch_size}")
+            # One-hot encode: (batch, H, W, C)
+            one_hot = F.one_hot(scene_tensor, num_classes=self.unet.config.in_channels).float()
+            # (batch, H, W, C) -> (batch, C, H, W)
+            sample = one_hot.permute(0, 3, 1, 2)
+        else:
+            # Start from random noise
+            sample = torch.randn(sample_shape, generator=generator, device=self.device)
+
+        return sample
+
+    def __call__(
+        self,
+        caption: Optional[str | list[str]] = None,
+        negative_prompt: Optional[str | list[str]] = None,
+        generator: Optional[torch.Generator] = None,
+        num_inference_steps: int = common_settings.NUM_INFERENCE_STEPS,
+        guidance_scale: float = common_settings.GUIDANCE_SCALE,
+        height: int = common_settings.MARIO_HEIGHT,
+        width: int = common_settings.MARIO_WIDTH,
+        raw_latent_sample: Optional[torch.FloatTensor] = None,
+        input_scene: Optional[torch.Tensor] = None,
+        output_type: str = "tensor",
+        batch_size: int = 1,
+        show_progress_bar: bool = True,
+    ) -> PipelineOutput:
+        """Generate a batch of images based on text input using the diffusion model.
+
+        Args:
+            caption: Text description(s) of the desired output. Can be a string or list of strings.
+            negative_prompt: Text description(s) of what should not appear in the output. String or list.
+            generator: Random number generator for reproducibility.
+            num_inference_steps: Number of denoising steps (more = higher quality, slower).
+            guidance_scale: How strongly the generation follows the text prompt (higher = stronger).
+            height: Height of the generated image in tiles.
+            width: Width of the generated image in tiles.
+            raw_latent_sample: Optional starting point for diffusion instead of random noise.
+                Must have the correct number of channels, matching the UNet.
+            input_scene: Optional 2D or 3D int tensor where each value corresponds to a tile type.
+                Will be converted to a one-hot encoding as the starting point.
+            output_type: Currently only "tensor" is supported.
+            batch_size: Number of samples to generate in parallel.
+            show_progress_bar: Whether to display a progress bar during denoising.
+
+        Returns:
+            PipelineOutput containing the generated image tensor (batch_size, ...).
+        """
+
+        # I would like to simplify the code to this, but the AI suggestion didn't work, and
+        # I did not feel good just pasting it all in. Will need to tackle it bit by bit.
+
+        # if caption is not None and self.text_encoder is None:
+        #     raise ValueError("Text encoder required for conditional generation")
+
+        # self.unet.eval()
+        # if self.text_encoder is not None:
+        #     self.text_encoder.to(self.device)
+        #     self.text_encoder.eval()
+        #
+        # with torch.no_grad():
+        #     # Process text inputs
+        #     captions = self.prepare_text_batch(caption, batch_size, "caption")
+        #     negatives = self.prepare_text_batch(negative_prompt, batch_size, "negative_prompt")
+
+        #     # Get embeddings
+        #     text_embeddings = self.prepare_embeddings(captions, negatives, batch_size)
+        #
+        #     # Set up initial latent state
+        #     sample = self.prepare_initial_sample(raw_latent_sample, input_scene,
+        #                                          batch_size, height, width, generator)
+
+        #     # Run diffusion process
+        #     sample = self.run_diffusion(sample, text_embeddings, num_inference_steps,
+        #                                 guidance_scale, generator, show_progress_bar,
+        #                                 has_caption=caption is not None,
+        #                                 has_negative=negative_prompt is not None)
+
+        #     # Format output
+        #     if output_type == "tensor":
+        #         sample = F.softmax(sample, dim=1)
+        #     else:
+        #         raise ValueError(f"Unsupported output type: {output_type}")
+
+        #     return PipelineOutput(images=sample)
+
+        # Validate the text encoder if we need it
+        if caption is not None and self.text_encoder is None:
+            raise ValueError("Text encoder is required for conditional generation")
+
+        self.unet.eval()
+        if self.text_encoder is not None:
+            self.text_encoder.to(self.device)
+            self.text_encoder.eval()
+
+        with torch.no_grad():
+            captions = self._prepare_text_batch(caption, batch_size, "caption")
+            negatives = self._prepare_text_batch(negative_prompt, batch_size, "negative_prompt")
+
+            # --- Prepare text embeddings ---
+            if isinstance(self.text_encoder, TransformerModel):
+                text_embeddings = text_model.get_embeddings(batch_size=batch_size,
+                                                            tokenizer=self.text_encoder.tokenizer,
+                                                            text_encoder=self.text_encoder,
+                                                            captions=captions,
+                                                            neg_captions=negatives,
+                                                            device=self.device)
+            else:  # Case for the pretrained text encoder
+                if self.supports_pretrained_split:  # If we have a split flag incorporated
+                    text_embeddings = st_helper.get_embeddings_split(batch_size=batch_size,
+                                                                     tokenizer=self.tokenizer,
+                                                                     model=self.text_encoder,
+                                                                     captions=captions,
+                                                                     neg_captions=negatives,
+                                                                     device=self.device)
+                else:
+                    text_embeddings = st_helper.get_embeddings(batch_size=batch_size,
+                                                               tokenizer=self.tokenizer,
+                                                               model=self.text_encoder,
+                                                               captions=captions,
+                                                               neg_captions=negatives,
+                                                               device=self.device)
+
+            # --- Set up initial latent state ---
+            sample = self._prepare_initial_sample(raw_latent_sample, input_scene,
+                                                  batch_size, height, width, generator)
+
+            # --- Set up diffusion process ---
+            self.scheduler.set_timesteps(num_inference_steps)
+
+            # Denoising loop
+            iterator = self.progress_bar(self.scheduler.timesteps) if show_progress_bar else self.scheduler.timesteps
+            for t in iterator:
+                # Handle conditional generation
+                if captions is not None:
+                    if negatives is not None:
+                        # Three copies for negative prompt guidance
+                        model_input = torch.cat([sample, sample, sample], dim=0)
+                    else:
+                        # Two copies for standard classifier-free guidance
+                        model_input = torch.cat([sample, sample], dim=0)
+                else:
+                    model_input = sample
+
+                # Predict the noise residual
+                model_kwargs = {"encoder_hidden_states": text_embeddings}
+                noise_pred = self.unet(model_input, t, **model_kwargs).sample
+
+                # Apply guidance
+                if captions is not None:
+                    if negatives is not None:
+                        # Split predictions into negative, unconditional, and text-conditional parts
+                        noise_pred_neg, noise_pred_uncond, noise_pred_text = noise_pred.chunk(3)
+                        noise_pred_guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                        noise_pred = noise_pred_guided - guidance_scale * (noise_pred_neg - noise_pred_uncond)
+                    else:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # Compute the previous sample: x_{t-1} = scheduler(x_t, noise_pred)
+                sample = self.scheduler.step(noise_pred, t, sample, generator=generator).prev_sample
+
+            # Convert to the output format
+            if output_type == "tensor":
+                if self.block_embeddings is not None:
+                    sample = get_scene_from_embeddings(sample, self.block_embeddings)
+                else:
+                    # Apply softmax to get probabilities for each tile type
+                    sample = F.softmax(sample, dim=1)
+                    sample = sample.detach().cpu()
+            else:
+                raise ValueError(f"Unsupported output type: {output_type}")
+
+        return PipelineOutput(images=sample)
+
+    def print_unet_architecture(self):
+        """Prints the architecture of the UNet model."""
+        print(self.unet)
+
+    def print_text_encoder_architecture(self):
+        """Prints the architecture of the text encoder model, if it exists."""
+        if self.text_encoder is not None:
+            print(self.text_encoder)
+        else:
+            print("No text encoder is set.")
+
+    def save_unet_architecture_pdf(self, height, width, filename="unet_architecture", batch_size=1, device=None):
+        """
+        torchview must be installed separately for this to work.
+
+        Saves a visualization of the UNet architecture as a PDF using torchview.
+        Args:
+            height: Height of the dummy input.
+            width: Width of the dummy input.
+            filename: Output PDF filename.
+            batch_size: Batch size for the dummy input.
+            device: Device to run the dummy input on (defaults to the pipeline device).
+        """
+        from torchview import draw_graph
+
+        if device is None:
+            device = self.device if hasattr(self, 'device') else 'cpu'
+        in_channels = self.unet.config.in_channels if hasattr(self.unet, 'config') else 1
+        sample_shape = (batch_size, in_channels, height, width)
+
+        dummy_x = torch.randn(size=sample_shape, device=device)
+        dummy_t = torch.tensor([0] * batch_size, dtype=torch.long, device=device)
+
+        # Prepare a dummy text embedding (match what the UNet expects)
+        if hasattr(self.unet, 'config') and hasattr(self.unet.config, 'cross_attention_dim'):
+            cross_attention_dim = self.unet.config.cross_attention_dim
+        else:
+            cross_attention_dim = 128  # fallback
+        encoder_hidden_states = torch.randn(batch_size, 1, cross_attention_dim, device=device)
+
+        self.unet.eval()
+        inputs = (dummy_x, dummy_t, encoder_hidden_states)
+        #self.unet.down_blocks = self.unet.down_blocks[:2]
+
+        graph = draw_graph(
+            model=self.unet,
+            input_data=inputs,
+            expand_nested=False,
+            #enable_output_shape=True,
+            #roll_out="nested",
+            depth=1
+        )
+        #graph.visual_graph.engine = "neato"
+        graph.visual_graph.attr(  #rankdir="LR",
+            nodesep="0.1",  # decrease space between nodes in the same rank (default ~0.25)
+            ranksep="0.2",  # decrease space between ranks (default ~0.5)
+            concentrate="true"  # merge edges between nodes in the same rank
+        )
+        graph.visual_graph.node_attr.update(
+            shape="rectangle",
+            width="1.5",  # narrow width
+            height="0.5"  # taller height to make vertical rectangles
+            #fixedsize="true"
+        )
+
+        graph.visual_graph.render(filename, format='pdf', cleanup=False)  # cleanup would remove intermediate files
+        graph.visual_graph.save('unet_architecture.dot')
+
+        print(f"UNet architecture saved to {filename}")
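
For reference, a hedged sketch of calling the pipeline above (not part of the commit; "path/to/model" and the caption are illustrative). With the default "tensor" output and no block embeddings, images holds per-tile softmax probabilities over the channel dimension, so an argmax recovers tile indices:

    from models.text_diffusion_pipeline import TextConditionalDDPMPipeline

    pipe = TextConditionalDDPMPipeline.from_pretrained("path/to/model").to("cuda")
    out = pipe(caption="full floor. two coins. one pipe.",
               num_inference_steps=30, guidance_scale=7.5, batch_size=2)
    tiles = out.images.argmax(dim=1)  # (batch, height, width) most likely tile per cell
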
models/text_model.py ADDED
@@ -0,0 +1,206 @@
+import argparse
+import math
+import os
+import json
+
+import torch
+import torch.nn as nn
+from safetensors.torch import save_file, load_file
+
+from tokenizer import Tokenizer
+
+def get_embeddings(batch_size, tokenizer, text_encoder, captions=None, neg_captions=None, device='cpu'):
+    max_length = text_encoder.max_seq_length
+    empty_ids = encode_token_captions([""] * batch_size, tokenizer, max_length, device=device)
+    embeddings = text_encoder.get_embeddings(empty_ids)
+
+    if captions is not None:
+        caption_ids = encode_token_captions(captions, tokenizer, max_length, device=device)
+        caption_embeddings = text_encoder.get_embeddings(caption_ids)
+        embeddings = torch.cat((embeddings, caption_embeddings), dim=0)
+
+    if neg_captions is not None:
+        neg_ids = encode_token_captions(neg_captions, tokenizer, max_length, device=device)
+        neg_embeddings = text_encoder.get_embeddings(neg_ids)
+        embeddings = torch.cat((neg_embeddings, embeddings), dim=0)
+
+    return embeddings.to(device)
+
+def encode_token_captions(captions, tokenizer, max_length, device='cpu'):
+    caption_ids = []
+    for caption in captions:
+        tokens = tokenizer.encode(caption)
+        caption_tokens = tokenizer.pad_sequence(tokens, max_length)
+        caption_ids.append(torch.tensor(caption_tokens, dtype=torch.long).unsqueeze(0))
+    return torch.cat(caption_ids, dim=0).to(device)
+
+
+# Transformer model for MLM training
+class TransformerModel(nn.Module):
+    def __init__(self, vocab_size, embedding_dim, hidden_dim, tokenizer=None, num_heads=8, num_layers=4, max_seq_length=100):
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.vocab_size = vocab_size
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.max_seq_length = max_seq_length
+
+        self.embedding = nn.Embedding(vocab_size, embedding_dim)
+        self.positional_encoding = self.create_positional_encoding(max_seq_length, embedding_dim)
+
+        encoder_layers = nn.TransformerEncoderLayer(
+            d_model=embedding_dim,
+            nhead=num_heads,
+            dim_feedforward=hidden_dim,
+            batch_first=True
+        )
+        self.transformer = nn.TransformerEncoder(encoder_layers, num_layers)
+        self.fc = nn.Linear(embedding_dim, vocab_size)
+
+        self.tokenizer = tokenizer
+
+    def create_positional_encoding(self, max_seq_length, embedding_dim):
+        # Sinusoidal positional encoding: a unique pattern for each position in the sequence.
+        # The frequencies create unique values; sin/cos bound those values.
+        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
+        # Create a set of divisors that produce different frequencies
+        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
+        pe = torch.zeros(max_seq_length, embedding_dim)
+        # Even dimensions use sin, odd dimensions use cos
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        return pe.unsqueeze(0)
+
+    def get_embeddings(self, x):
+        """Gets the actual latent embedding vectors."""
+        # Ensure the positional encoding is on the same device as the input
+        pe = self.positional_encoding[:, :x.size(1), :].to(x.device)
+        # Embed the input and add the positional encoding
+        embedded = self.embedding(x) + pe
+        return self.transformer(embedded)
+
+    def forward(self, x):
+        """Gets the token logits over the vocabulary."""
+        transformer_out = self.get_embeddings(x)
+        # Project to vocabulary size
+        return self.fc(transformer_out)
+
+    def save_pretrained(self, save_directory):
+        os.makedirs(save_directory, exist_ok=True)
+
+        config = {
+            "vocab_size": self.vocab_size,
+            "embedding_dim": self.embedding_dim,
+            "hidden_dim": self.hidden_dim,
+            "num_heads": self.num_heads,
+            "num_layers": self.num_layers,
+            "max_seq_length": self.max_seq_length,
+        }
+        with open(os.path.join(save_directory, "config.json"), "w") as f:
+            json.dump(config, f)
+
+        # Save model weights
+        save_file(self.state_dict(), os.path.join(save_directory, "model.safetensors"))
+
+        # Save the tokenizer if present
+        if self.tokenizer is not None:
+            self.tokenizer.save(os.path.join(save_directory, "tokenizer.pkl"))
+
+    @classmethod
+    def from_pretrained(cls, load_directory):
+        with open(os.path.join(load_directory, "config.json")) as f:
+            config = json.load(f)
+
+        model = cls(**config)
+
+        # Load weights
+        state_dict = load_file(os.path.join(load_directory, "model.safetensors"))
+        model.load_state_dict(state_dict)
+
+        # Load the tokenizer if available
+        tokenizer_path = os.path.join(load_directory, "tokenizer.pkl")
+        if os.path.exists(tokenizer_path):
+            tokenizer = Tokenizer()
+            tokenizer.load(tokenizer_path)
+            model.tokenizer = tokenizer
+
+        return model
+
+    def print_architecture(self, inputs=None):
+        """Loads a model named via command-line args and renders its architecture with torchview.
+        Note: this method parses CLI arguments itself, so it only works when invoked from a script,
+        and inputs should contain example token IDs for the dummy forward pass."""
+        parser = argparse.ArgumentParser()
+        parser.add_argument("--model_path", type=str, required=True, help="Path to trained transformer model")
+        parser.add_argument("--json", type=str, default="SMB1_LevelsAndCaptions-regular-test.json", help="Path to dataset json file")
+        parser.add_argument("--num_samples", type=int, default=10, help="Number of captions to evaluate")
+        parser.add_argument("--mask_prob", type=float, default=0.15, help="Probability of masking each token")
+        parser.add_argument("--compare_checkpoints", action="store_true", default=False, help="Run comparison across all model checkpoints")
+        args = parser.parse_args()
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model = TransformerModel.from_pretrained(args.model_path).to(device)
+        print(f"Loaded model from {args.model_path}")
+
+        from torchview import draw_graph
+
+        graph = draw_graph(
+            model=model,
+            input_data=inputs,
+            expand_nested=False,
+            #enable_output_shape=True,
+            #roll_out="nested",
+            depth=1
+        )
+
+        # Save plot
+        filename = 'mlm_architecture'
+        graph.visual_graph.render(filename, format='pdf', cleanup=False)  # cleanup would remove intermediate files
+        #graph.visual_graph.save('unet_architecture.dot')
+
+    def save_architecture_pdf(self, filename="transformer_architecture.pdf", input_length=32):
+        """Save a visualization of the model architecture as a PDF using torchview."""
+        try:
+            from torchview import draw_graph
+        except ImportError:
+            raise ImportError("torchview is required for model visualization. Install with 'pip install torchview'.")
+        # Create a dummy input of the correct type for the model
+        captions = ["full floor. two coins. one pipe.", "floor with two gaps. one cannon. many enemies."]
+        tensor = encode_token_captions(captions, self.tokenizer, self.max_seq_length, device=next(self.parameters()).device)
+        input_length = tensor.size(1) if tensor.dim() > 1 else self.max_seq_length
+
+        num_tokens_list = [len(self.tokenizer.encode(c)) for c in captions]
+        input_length = max(num_tokens_list) if num_tokens_list else input_length
+        dummy_input = torch.zeros((1, input_length), dtype=torch.long, device=next(self.parameters()).device)
+
+        # Draw the graph and save as PNG
+        graph = draw_graph(self, input_data=dummy_input, expand_nested=True, save_graph=True, filename=filename.replace('.pdf', ''), directory=".", depth=2)
+        png_file = filename.replace('.pdf', '.png')
+        # Convert the PNG to PDF
+        if os.path.exists(png_file):
+            try:
+                from PIL import Image
+                im = Image.open(png_file)
+                im.save(filename, "PDF", resolution=100.0)
+                print(f"Saved architecture PDF to {filename}")
+                # Optionally, remove the PNG file
+                os.remove(png_file)
+            except ImportError:
+                print(f"PIL not installed. Architecture saved as PNG: {png_file}")
+            except Exception as e:
+                print(f"Could not convert PNG to PDF: {e}")
+        else:
+            print(f"Could not find PNG file to convert: {png_file}")
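
A small sketch of how TransformerModel and get_embeddings fit together (not part of the commit; embedding_dim and hidden_dim are example values, and "Mario_Tokenizer.pkl" is just the default name from the tokenizer.py CLI below):

    from tokenizer import Tokenizer
    from models.text_model import TransformerModel, get_embeddings

    tok = Tokenizer()
    tok.load("Mario_Tokenizer.pkl")  # hypothetical path to a saved vocabulary
    encoder = TransformerModel(vocab_size=tok.get_vocab_size(), embedding_dim=128,
                               hidden_dim=256, tokenizer=tok)
    # Returns unconditional embeddings concatenated with caption embeddings along dim 0
    emb = get_embeddings(batch_size=1, tokenizer=tok, text_encoder=encoder,
                         captions=["full floor. two coins. one pipe."])
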
tokenizer.py ADDED
@@ -0,0 +1,147 @@
+import json
+import re
+from collections import Counter
+import pickle
+import argparse
+
+class Tokenizer:
+    def __init__(self):
+        self.special_tokens = ["[PAD]", "[MASK]"]
+        self.vocab = {}
+        self.token_to_id = {}
+        self.id_to_token = {}
+
+    def tokenize(self, text):
+        # Match words, numbers, periods, and commas as separate tokens
+        tokens = re.findall(r'\w+|[.,]|\[mask\]|\[pad\]', text.lower())
+        # Restore MASK and PAD to all caps
+        modified_list = []
+        for s in tokens:
+            modified_s = s.replace("[mask]", "[MASK]").replace("[pad]", "[PAD]")
+            modified_list.append(modified_s)
+        return modified_list
+
+    def pad_sequence(self, tokens, length):
+        """Pads a sequence of token IDs to the given length with the '[PAD]' token ID."""
+        if len(tokens) > length:
+            raise ValueError(f"Token sequence length {len(tokens)} exceeds specified length {length}.")
+
+        pad_token = self.token_to_id["[PAD]"]
+        return tokens + [pad_token] * (length - len(tokens))
+
+    def build_vocab(self, dataset_path, min_freq=1):
+        token_counter = Counter()
+
+        with open(dataset_path, 'r') as f:
+            data = json.load(f)
+            for entry in data:
+                caption = entry['caption']
+                tokens = self.tokenize(caption)
+                token_counter.update(tokens)
+
+        # Keep tokens that meet the minimum frequency
+        tokens = [tok for tok, count in token_counter.items() if count >= min_freq]
+
+        # Ensure special tokens are always included
+        all_tokens = self.special_tokens + sorted(tokens)
+
+        # Build vocab dictionaries
+        self.vocab = {tok: idx for idx, tok in enumerate(all_tokens)}
+        self.token_to_id = self.vocab
+        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}
+
+        print(f"Vocabulary size: {len(self.vocab)}")
+
+    def encode(self, text):
+        tokens = self.tokenize(text)
+        encoded = []
+        for tok in tokens:
+            if tok not in self.token_to_id:
+                raise ValueError(f"Unknown token encountered: {tok} in {text}")
+            encoded.append(self.token_to_id[tok])
+        return encoded
+
+    def encode_batch(self, texts, pad_to_length=None):
+        """
+        Encode a batch of texts into token IDs with padding to ensure uniform length.
+
+        Args:
+            texts (list): A list of strings to encode
+            pad_to_length (int, optional): Length to pad all sequences to. If None,
+                pads to the length of the longest sequence.
+
+        Returns:
+            list: A list of lists, where each inner list contains the token IDs for a text
+        """
+        # Get the padding token ID
+        pad_token = self.token_to_id["[PAD]"]
+
+        # First encode all texts
+        encoded_texts = []
+        for text in texts:
+            try:
+                encoded = self.encode(text)
+                encoded_texts.append(encoded)
+            except ValueError as e:
+                raise ValueError(f"Error encoding text: {text}. {str(e)}")
+
+        # Determine the padding length
+        if pad_to_length is None:
+            pad_to_length = max(len(seq) for seq in encoded_texts)
+
+        # Pad sequences to uniform length
+        padded_texts = []
+        for seq in encoded_texts:
+            if len(seq) > pad_to_length:
+                # Truncate if too long
+                padded_texts.append(seq[:pad_to_length])
+            else:
+                # Pad if too short
+                padding = [pad_token] * (pad_to_length - len(seq))
+                padded_texts.append(seq + padding)
+
+        return padded_texts
+
+    def decode(self, token_ids):
+        return ' '.join(self.id_to_token[tok_id] for tok_id in token_ids)
+
+    def save(self, path):
+        with open(path, 'wb') as f:
+            pickle.dump({'vocab': self.vocab}, f)
+
+    def load(self, path):
+        with open(path, 'rb') as f:
+            data = pickle.load(f)
+        self.vocab = data['vocab']
+        self.token_to_id = self.vocab
+        self.id_to_token = {idx: tok for tok, idx in self.vocab.items()}
+
+    def get_vocab(self):
+        return sorted(self.vocab.keys())
+
+    def get_vocab_size(self):
+        return len(self.vocab)
+
+if __name__ == "__main__":
+    tokenizer = Tokenizer()
+
+    parser = argparse.ArgumentParser(description="Tokenizer utility for saving and loading vocabularies.")
+    parser.add_argument("action", choices=["save", "load"], help="Action to perform: 'save' or 'load'.")
+    parser.add_argument("--json_file", type=str, default='Mario_LevelsAndCaptions.json', help="Path to the JSON file containing the dataset (required for 'save').")
+    parser.add_argument("--pkl_file", type=str, default='Mario_Tokenizer.pkl', help="Path to the pickle file to save/load the tokenizer.")
+
+    args = parser.parse_args()
+
+    if args.action == "save":
+        if not args.json_file:
+            raise ValueError("The --json_file argument is required for the 'save' action.")
+        tokenizer.build_vocab(args.json_file)
+        tokenizer.save(args.pkl_file)
+    elif args.action == "load":
+        tokenizer.load(args.pkl_file)
+
+    # Example usage
+    #print(tokenizer.encode("floor with one gap. one enemy."))
+    #print(tokenizer.get_vocab())
+    #for id, token in tokenizer.id_to_token.items():
+    #    print(id, ":", token)
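
An example round trip with the tokenizer (not part of the commit; the dataset path is the CLI default above, and the caption comes from the commented examples):

    tok = Tokenizer()
    tok.build_vocab("Mario_LevelsAndCaptions.json")
    ids = tok.encode("floor with one gap. one enemy.")
    ids = tok.pad_sequence(ids, 20)  # pad with [PAD] IDs up to length 20
    print(tok.decode(ids))
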
util/common_settings.py ADDED
@@ -0,0 +1,18 @@
+
+NUM_INFERENCE_STEPS = 30
+GUIDANCE_SCALE = 7.5
+
+MARIO_HEIGHT = 16
+MARIO_WIDTH = 16
+
+MARIO_TILE_PIXEL_DIM = 16
+MARIO_TILE_COUNT = 13
+
+LR_HEIGHT = 32
+LR_WIDTH = 32
+
+LR_TILE_PIXEL_DIM = 8
+LR_TILE_COUNT = 8
+
+MEGAMAN_HEIGHT = 14
+MEGAMAN_WIDTH = 16