Sapir Weissbuch committed
Commit 77d3abf · unverified · 2 parent(s): fc02e02 e46ff5e

Merge pull request #5 from LightricksResearch/safetensors-ckpts

scripts/to_safetensors.py ADDED
@@ -0,0 +1,109 @@
+import argparse
+from pathlib import Path
+from typing import Any, Dict
+import safetensors.torch
+import torch
+import json
+import shutil
+
+
+def load_text_encoder(index_path: Path) -> Dict:
+    with open(index_path, 'r') as f:
+        index: Dict = json.load(f)
+
+    loaded_tensors = {}
+    for part_file in set(index.get("weight_map", {}).values()):
+        tensors = safetensors.torch.load_file(index_path.parent / part_file, device='cpu')
+        for tensor_name in tensors:
+            loaded_tensors[tensor_name] = tensors[tensor_name]
+
+    return loaded_tensors
+
+
+def convert_unet(unet: Dict, add_prefix=True) -> Dict:
+    if add_prefix:
+        return {"model.diffusion_model." + key: value for key, value in unet.items()}
+    return unet
+
+
+def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
+    state_dict = torch.load(vae_path / "autoencoder.pth", weights_only=True)
+    stats_path = vae_path / "per_channel_statistics.json"
+    if stats_path.exists():
+        with open(stats_path, 'r') as f:
+            data = json.load(f)
+        transposed_data = list(zip(*data["data"]))
+        data_dict = {
+            f"{'vae.' if add_prefix else ''}per_channel_statistics.{col}": torch.tensor(vals)
+            for col, vals in zip(data["columns"], transposed_data)
+        }
+    else:
+        data_dict = {}
+
+    result = {("vae." if add_prefix else "") + key: value for key, value in state_dict.items()}
+    result.update(data_dict)
+    return result
+
+
+def convert_encoder(encoder: Dict) -> Dict:
+    return {"text_encoders.t5xxl.transformer." + key: value for key, value in encoder.items()}
+
+
+def save_config(config_src: str, config_dst: str):
+    shutil.copy(config_src, config_dst)
+
+
+def load_vae_config(vae_path: Path) -> str:
+    config_path = vae_path / "config.json"
+    if not config_path.exists():
+        raise FileNotFoundError(f"VAE config file {config_path} not found.")
+    return str(config_path)
+
+
+def main(unet_path: str, vae_path: str, out_path: str, mode: str,
+         unet_config_path: str = None, scheduler_config_path: str = None) -> None:
+    unet = convert_unet(torch.load(unet_path, weights_only=True), add_prefix=(mode == 'single'))
+
+    # Load VAE from directory and config
+    vae = convert_vae(Path(vae_path), add_prefix=(mode == 'single'))
+    vae_config_path = load_vae_config(Path(vae_path))
+
+    if mode == 'single':
+        result = {**unet, **vae}
+        safetensors.torch.save_file(result, out_path)
+    elif mode == 'separate':
+        # Create directories for unet, vae, and scheduler
+        unet_dir = Path(out_path) / 'unet'
+        vae_dir = Path(out_path) / 'vae'
+        scheduler_dir = Path(out_path) / 'scheduler'
+
+        unet_dir.mkdir(parents=True, exist_ok=True)
+        vae_dir.mkdir(parents=True, exist_ok=True)
+        scheduler_dir.mkdir(parents=True, exist_ok=True)
+
+        # Save unet and vae safetensors with the name diffusion_pytorch_model.safetensors
+        safetensors.torch.save_file(unet, unet_dir / 'diffusion_pytorch_model.safetensors')
+        safetensors.torch.save_file(vae, vae_dir / 'diffusion_pytorch_model.safetensors')
+
+        # Save config files for unet, vae, and scheduler
+        if unet_config_path:
+            save_config(unet_config_path, unet_dir / 'config.json')
+        if vae_config_path:
+            save_config(vae_config_path, vae_dir / 'config.json')
+        if scheduler_config_path:
+            save_config(scheduler_config_path, scheduler_dir / 'scheduler_config.json')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--unet_path', '-u', type=str, default='unet/ema-002.pt')
+    parser.add_argument('--vae_path', '-v', type=str, default='vae/')
+    parser.add_argument('--out_path', '-o', type=str, default='xora.safetensors')
+    parser.add_argument('--mode', '-m', type=str, choices=['single', 'separate'], default='single',
+                        help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.")
+    parser.add_argument('--unet_config_path', type=str, help="Path to the UNet config file (for separate mode)")
+    parser.add_argument('--scheduler_config_path', type=str,
+                        help="Path to the Scheduler config file (for separate mode)")

+    args = parser.parse_args()
+    main(**args.__dict__)
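
Note: in 'single' mode every tensor lands in one safetensors file under the prefixes added by convert_unet ("model.diffusion_model.") and convert_vae ("vae.", including "vae.per_channel_statistics.*"). A minimal sketch of splitting such a file back into per-component state dicts (the file name is just this script's default output; str.removeprefix needs Python 3.9+):

import safetensors.torch

state = safetensors.torch.load_file("xora.safetensors", device="cpu")
# Keys written by convert_unet() in 'single' mode
unet_sd = {k.removeprefix("model.diffusion_model."): v
           for k, v in state.items() if k.startswith("model.diffusion_model.")}
# Keys written by convert_vae(), per-channel statistics included
vae_sd = {k.removeprefix("vae."): v
          for k, v in state.items() if k.startswith("vae.")}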
xora/examples/image_to_video.py CHANGED
@@ -5,83 +5,107 @@ from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
 from xora.schedulers.rf import RectifiedFlowScheduler
 from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
 from pathlib import Path
-from transformers import T5EncoderModel
-
-
-model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
-vae_local_path = Path("/opt/models/checkpoints/vae_training/causal_vvae_32x32x8_420m_cont_32/step_2296000")
-dtype = torch.float32
-vae = CausalVideoAutoencoder.from_pretrained(
-    pretrained_model_name_or_path=vae_local_path,
-    revision=False,
-    torch_dtype=torch.bfloat16,
-    load_in_8bit=False,
-).cuda()
-transformer_config_path = Path("/opt/txt2img/txt2img/config/transformer3d/xora_v1.2-L.json")
-transformer_config = Transformer3DModel.load_config(transformer_config_path)
-transformer = Transformer3DModel.from_config(transformer_config)
-transformer_local_path = Path("/opt/models/logs/v1.2-vae-mf-medHR-mr-cvae-first-frame-cond-4k-seq/ckpt/01822000/model.pt")
-transformer_ckpt_state_dict = torch.load(transformer_local_path)
-transformer.load_state_dict(transformer_ckpt_state_dict, True)
-transformer = transformer.cuda()
-unet = transformer
-scheduler_config_path = Path("/opt/txt2img/txt2img/config/scheduler/RF_SD3_shifted.json")
-scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
-scheduler = RectifiedFlowScheduler.from_config(scheduler_config)
-patchifier = SymmetricPatchifier(patch_size=1)
-# text_encoder = T5EncoderModel.from_pretrained("t5-v1_1-xxl")
-
-submodel_dict = {
-    "unet": unet,
-    "transformer": transformer,
-    "patchifier": patchifier,
-    "text_encoder": None,
-    "scheduler": scheduler,
-    "vae": vae,
-
-}
-
-pipeline = VideoPixArtAlphaPipeline.from_pretrained(model_name_or_path,
-                                                    safety_checker=None,
-                                                    revision=None,
-                                                    torch_dtype=dtype,
-                                                    **submodel_dict,
-                                                    )
-
-num_inference_steps=20
-num_images_per_prompt=2
-guidance_scale=3
-height=512
-width=768
-num_frames=57
-frame_rate=25
-# sample = {
-#     "prompt": "A cat",  # (B, L, E)
-#     'prompt_attention_mask': None,  # (B , L)
-#     'negative_prompt': "Ugly deformed",
-#     'negative_prompt_attention_mask': None  # (B , L)
-# }
-
-sample = torch.load("/opt/sample.pt")
-for _, item in sample.items():
-    if item is not None:
-        item = item.cuda()
-media_items = torch.load("/opt/sample_media.pt")
-
-images = pipeline(
-    num_inference_steps=num_inference_steps,
-    num_images_per_prompt=num_images_per_prompt,
-    guidance_scale=guidance_scale,
-    generator=None,
-    output_type="pt",
-    callback_on_step_end=None,
-    height=height,
-    width=width,
-    num_frames=num_frames,
-    frame_rate=frame_rate,
-    **sample,
-    is_video=True,
-    vae_per_channel_normalize=True,
-).images
-
-print()
+from transformers import T5EncoderModel, T5Tokenizer
+import safetensors.torch
+import json
+import argparse
+
+def load_vae(vae_dir):
+    vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
+    vae_config_path = vae_dir / "config.json"
+    with open(vae_config_path, 'r') as f:
+        vae_config = json.load(f)
+    vae = CausalVideoAutoencoder.from_config(vae_config)
+    vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
+    vae.load_state_dict(vae_state_dict)
+    return vae.cuda().to(torch.bfloat16)
+
+def load_unet(unet_dir):
+    unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
+    unet_config_path = unet_dir / "config.json"
+    transformer_config = Transformer3DModel.load_config(unet_config_path)
+    transformer = Transformer3DModel.from_config(transformer_config)
+    unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
+    transformer.load_state_dict(unet_state_dict, strict=True)
+    return transformer.cuda()
+
+def load_scheduler(scheduler_dir):
+    scheduler_config_path = scheduler_dir / "scheduler_config.json"
+    scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
+    return RectifiedFlowScheduler.from_config(scheduler_config)
+
+def main():
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(description='Load models from separate directories')
+    parser.add_argument('--separate_dir', type=str, required=True, help='Path to the directory containing unet, vae, and scheduler subdirectories')
+    args = parser.parse_args()
+
+    # Paths for the separate mode directories
+    separate_dir = Path(args.separate_dir)
+    unet_dir = separate_dir / 'unet'
+    vae_dir = separate_dir / 'vae'
+    scheduler_dir = separate_dir / 'scheduler'
+
+    # Load models
+    vae = load_vae(vae_dir)
+    unet = load_unet(unet_dir)
+    scheduler = load_scheduler(scheduler_dir)
+
+    # Patchifier (remains the same)
+    patchifier = SymmetricPatchifier(patch_size=1)
+
+    # text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to("cuda")
+    # tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")
+
+    # Use submodels for the pipeline
+    submodel_dict = {
+        "transformer": unet,  # using unet for transformer
+        "patchifier": patchifier,
+        "text_encoder": None,
+        "tokenizer": None,
+        "scheduler": scheduler,
+        "vae": vae,
+    }
+
+    model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
+    pipeline = VideoPixArtAlphaPipeline(
+        **submodel_dict
+    ).to("cuda")
+
+    num_inference_steps = 20
+    num_images_per_prompt = 1
+    guidance_scale = 3
+    height = 512
+    width = 768
+    num_frames = 57
+    frame_rate = 25
+
+    # Sample input stays the same
+    sample = torch.load("/opt/sample_media.pt")
+    for key, item in sample.items():
+        if item is not None:
+            sample[key] = item.cuda()
+
+    # media_items = torch.load("/opt/sample_media.pt")
+
+    # Generate images (video frames)
+    images = pipeline(
+        num_inference_steps=num_inference_steps,
+        num_images_per_prompt=num_images_per_prompt,
+        guidance_scale=guidance_scale,
+        generator=None,
+        output_type="pt",
+        callback_on_step_end=None,
+        height=height,
+        width=width,
+        num_frames=num_frames,
+        frame_rate=frame_rate,
+        **sample,
+        is_video=True,
+        vae_per_channel_normalize=True,
+    ).images
+
+    print("Generated video frames.")
+
+if __name__ == "__main__":
+    main()
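
For reference, a hypothetical post-processing sketch for the returned tensor. It assumes images is a float tensor in [0, 1] with layout (batch, channels, frames, height, width); neither the layout nor the value range is confirmed by this diff:

import torch

def to_uint8_frames(images: torch.Tensor) -> torch.Tensor:
    # Assumed layout: (batch, channels, frames, height, width), values in [0, 1].
    frames = images.clamp(0, 1).mul(255).round().to(torch.uint8)
    # Reorder to (batch, frames, height, width, channels) for per-frame export.
    return frames.permute(0, 2, 3, 4, 1).cpu()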
xora/examples/text_to_video.py CHANGED
@@ -5,84 +5,104 @@ from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
 from xora.schedulers.rf import RectifiedFlowScheduler
 from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
 from pathlib import Path
-from transformers import T5EncoderModel
-
-
-model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
-vae_local_path = Path("/opt/models/checkpoints/vae_training/causal_vvae_32x32x8_420m_cont_32/step_2296000")
-dtype = torch.float32
-vae = CausalVideoAutoencoder.from_pretrained(
-    pretrained_model_name_or_path=vae_local_path,
-    revision=False,
-    torch_dtype=torch.bfloat16,
-    load_in_8bit=False,
-).cuda()
-transformer_config_path = Path("/opt/txt2img/txt2img/config/transformer3d/xora_v1.2-L.json")
-transformer_config = Transformer3DModel.load_config(transformer_config_path)
-transformer = Transformer3DModel.from_config(transformer_config)
-transformer_local_path = Path("/opt/models/logs/v1.2-vae-mf-medHR-mr-cvae-nl/ckpt/01760000/model.pt")
-transformer_ckpt_state_dict = torch.load(transformer_local_path)
-transformer.load_state_dict(transformer_ckpt_state_dict, True)
-transformer = transformer.cuda()
-unet = transformer
-scheduler_config_path = Path("/opt/txt2img/txt2img/config/scheduler/RF_SD3_shifted.json")
-scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
-scheduler = RectifiedFlowScheduler.from_config(scheduler_config)
-patchifier = SymmetricPatchifier(patch_size=1)
-# text_encoder = T5EncoderModel.from_pretrained("t5-v1_1-xxl")
-
-submodel_dict = {
-    "unet": unet,
-    "transformer": transformer,
-    "patchifier": patchifier,
-    "text_encoder": None,
-    "scheduler": scheduler,
-    "vae": vae,
-
-}
-
-pipeline = VideoPixArtAlphaPipeline.from_pretrained(model_name_or_path,
-                                                    safety_checker=None,
-                                                    revision=None,
-                                                    torch_dtype=dtype,
-                                                    **submodel_dict,
-                                                    )
-
-num_inference_steps=20
-num_images_per_prompt=2
-guidance_scale=3
-height=512
-width=768
-num_frames=57
-frame_rate=25
-# sample = {
-#     "prompt": "A cat",  # (B, L, E)
-#     'prompt_attention_mask': None,  # (B , L)
-#     'negative_prompt': "Ugly deformed",
-#     'negative_prompt_attention_mask': None  # (B , L)
-# }
-
-sample = torch.load("/opt/sample.pt")
-for _, item in sample.items():
-    if item is not None:
-        item = item.cuda()
-
-
-
-images = pipeline(
-    num_inference_steps=num_inference_steps,
-    num_images_per_prompt=num_images_per_prompt,
-    guidance_scale=guidance_scale,
-    generator=None,
-    output_type="pt",
-    callback_on_step_end=None,
-    height=height,
-    width=width,
-    num_frames=num_frames,
-    frame_rate=frame_rate,
-    **sample,
-    is_video=True,
-    vae_per_channel_normalize=True,
-).images
-
-print()
+from transformers import T5EncoderModel, T5Tokenizer
+import safetensors.torch
+import json
+import argparse
+
+def load_vae(vae_dir):
+    vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
+    vae_config_path = vae_dir / "config.json"
+    with open(vae_config_path, 'r') as f:
+        vae_config = json.load(f)
+    vae = CausalVideoAutoencoder.from_config(vae_config)
+    vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
+    vae.load_state_dict(vae_state_dict)
+    return vae.cuda().to(torch.bfloat16)
+
+def load_unet(unet_dir):
+    unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
+    unet_config_path = unet_dir / "config.json"
+    transformer_config = Transformer3DModel.load_config(unet_config_path)
+    transformer = Transformer3DModel.from_config(transformer_config)
+    unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
+    transformer.load_state_dict(unet_state_dict, strict=True)
+    return transformer.cuda()
+
+def load_scheduler(scheduler_dir):
+    scheduler_config_path = scheduler_dir / "scheduler_config.json"
+    scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
+    return RectifiedFlowScheduler.from_config(scheduler_config)
+
+def main():
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(description='Load models from separate directories')
+    parser.add_argument('--separate_dir', type=str, required=True, help='Path to the directory containing unet, vae, and scheduler subdirectories')
+    args = parser.parse_args()
+
+    # Paths for the separate mode directories
+    separate_dir = Path(args.separate_dir)
+    unet_dir = separate_dir / 'unet'
+    vae_dir = separate_dir / 'vae'
+    scheduler_dir = separate_dir / 'scheduler'
+
+    # Load models
+    vae = load_vae(vae_dir)
+    unet = load_unet(unet_dir)
+    scheduler = load_scheduler(scheduler_dir)
+
+    # Patchifier (remains the same)
+    patchifier = SymmetricPatchifier(patch_size=1)
+
+    text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to("cuda")
+    tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")
+
+    # Use submodels for the pipeline
+    submodel_dict = {
+        "transformer": unet,  # using unet for transformer
+        "patchifier": patchifier,
+        "scheduler": scheduler,
+        "text_encoder": text_encoder,
+        "tokenizer": tokenizer,
+        "vae": vae,
+    }
+
+    pipeline = VideoPixArtAlphaPipeline(**submodel_dict).to("cuda")
+
+    # Sample input
+    num_inference_steps = 20
+    num_images_per_prompt = 2
+    guidance_scale = 3
+    height = 512
+    width = 768
+    num_frames = 57
+    frame_rate = 25
+    sample = {
+        "prompt": "A middle-aged man with glasses and a salt-and-pepper beard is driving a car and talking, gesturing with his right hand. "
+        "The man is wearing a dark blue zip-up jacket and a light blue collared shirt. He is sitting in the driver's seat of a car with a black interior. The car is moving on a road with trees and bushes on either side. The man has a serious expression on his face and is looking straight ahead.",
+        'prompt_attention_mask': None,  # Adjust attention masks as needed
+        'negative_prompt': "Ugly deformed",
+        'negative_prompt_attention_mask': None
+    }
+
+    # Generate images (video frames)
+    images = pipeline(
+        num_inference_steps=num_inference_steps,
+        num_images_per_prompt=num_images_per_prompt,
+        guidance_scale=guidance_scale,
+        generator=None,
+        output_type="pt",
+        callback_on_step_end=None,
+        height=height,
+        width=width,
+        num_frames=num_frames,
+        frame_rate=frame_rate,
+        **sample,
+        is_video=True,
+        vae_per_channel_normalize=True,
+    ).images
+
+    print("Generated images (video frames).")
+
+if __name__ == "__main__":
+    main()
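
Both example scripts expect --separate_dir to point at a directory produced by scripts/to_safetensors.py in 'separate' mode. A small sketch that validates the expected layout up front (the helper name check_separate_dir is illustrative, not part of this PR):

from pathlib import Path

def check_separate_dir(root_dir: str) -> None:
    # Files written by to_safetensors.py --mode separate
    root = Path(root_dir)
    expected = [
        root / "unet" / "diffusion_pytorch_model.safetensors",
        root / "unet" / "config.json",
        root / "vae" / "diffusion_pytorch_model.safetensors",
        root / "vae" / "config.json",
        root / "scheduler" / "scheduler_config.json",
    ]
    missing = [str(p) for p in expected if not p.exists()]
    if missing:
        raise FileNotFoundError(f"Missing expected files: {missing}")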
xora/models/autoencoders/causal_video_autoencoder.py CHANGED
@@ -126,6 +126,13 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
         return json.dumps(self.config.__dict__)
 
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        per_channel_statistics_prefix = "per_channel_statistics."
+        ckpt_state_dict = {
+            key: value
+            for key, value in state_dict.items()
+            if not key.startswith(per_channel_statistics_prefix)
+        }
+
         model_keys = set(name for name, _ in self.named_parameters())
 
         key_mapping = {
@@ -133,9 +140,8 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
             "downsamplers.0": "downsample",
             "upsamplers.0": "upsample",
         }
-
         converted_state_dict = {}
-        for key, value in state_dict.items():
+        for key, value in ckpt_state_dict.items():
             for k, v in key_mapping.items():
                 key = key.replace(k, v)
 
@@ -147,6 +153,20 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
 
         super().load_state_dict(converted_state_dict, strict=strict)
 
+        data_dict = {
+            key.removeprefix(per_channel_statistics_prefix): value
+            for key, value in state_dict.items()
+            if key.startswith(per_channel_statistics_prefix)
+        }
+        if len(data_dict) > 0:
+            self.register_buffer("std_of_means", data_dict["std-of-means"])
+            self.register_buffer(
+                "mean_of_means",
+                data_dict.get(
+                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
+                ),
+            )
+
     def last_layer(self):
         if hasattr(self.decoder, "conv_out"):
             if isinstance(self.decoder.conv_out, nn.Sequential):
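
To illustrate the new behavior: load_state_dict now filters per_channel_statistics.* entries out of the incoming state dict before the usual key mapping, then registers the statistics as buffers (std_of_means, mean_of_means). A self-contained sketch of that split, with placeholder key names and shapes:

import torch

state_dict = {
    "decoder.conv_out.weight": torch.randn(4, 4),            # regular weight (placeholder)
    "per_channel_statistics.std-of-means": torch.ones(128),  # statistics entry
}
prefix = "per_channel_statistics."
ckpt_state_dict = {k: v for k, v in state_dict.items() if not k.startswith(prefix)}
data_dict = {k.removeprefix(prefix): v for k, v in state_dict.items() if k.startswith(prefix)}
assert set(ckpt_state_dict) == {"decoder.conv_out.weight"}
assert set(data_dict) == {"std-of-means"}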