jadechoghari committed on
Commit
80a6dc1
1 Parent(s): 1df7b61

add safetensors

Browse files
Files changed (1) hide show
  1. pipeline.py +10 -10
pipeline.py CHANGED
@@ -6,8 +6,7 @@ import sys
6
  from huggingface_hub import hf_hub_download
7
  from safetensors.torch import load_file
8
  import os
9
- from torchvision.utils import save_image
10
- from PIL import Image
11
  from .vae import AutoencoderKL
12
  from .mar import mar_base, mar_large, mar_huge
13
 
@@ -46,20 +45,22 @@ class MARModel(DiffusionPipeline):
46
  if model_type == "mar_base":
47
  diffloss_d = 6
48
  diffloss_w = 1024
 
49
  elif model_type == "mar_large":
50
  diffloss_d = 8
51
  diffloss_w = 1280
 
52
  elif model_type == "mar_huge":
53
  diffloss_d = 12
54
  diffloss_w = 1536
 
55
  else:
56
  raise NotImplementedError
57
  # download and load the model weights (.safetensors or .pth)
58
  model_checkpoint_path = hf_hub_download(
59
  repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
60
- filename=kwargs.get("model_filename", "checkpoint-last.pth")
61
  )
62
- model_checkpoint_path = kwargs.get("model_checkpoint_path", "./mar/checkpoint-last.pth")
63
 
64
  model_fn = model_mapping[model_type]
65
 
@@ -70,7 +71,8 @@ class MARModel(DiffusionPipeline):
70
  num_sampling_steps=str(num_sampling_steps_diffloss)
71
  ).cuda()
72
 
73
- state_dict = torch.load(f"./mar/checkpoint-last.pth")["model_ema"]
 
74
  model.load_state_dict(state_dict)
75
  model.eval()
76
 
@@ -85,7 +87,7 @@ class MARModel(DiffusionPipeline):
85
  vae = vae.to(device).eval()
86
 
87
  # set up user-specified or default values for generation
88
- seed = kwargs.get("seed", 0)
89
  torch.manual_seed(seed)
90
  np.random.seed(seed)
91
 
@@ -93,9 +95,7 @@ class MARModel(DiffusionPipeline):
93
  cfg_scale = kwargs.get("cfg_scale", 4)
94
  cfg_schedule = kwargs.get("cfg_schedule", "constant")
95
  temperature = kwargs.get("temperature", 1.0)
96
- # class_labels = kwargs.get("class_labels", 207, 360, 388, 113, 355, 980, 323, 979)
97
- class_labels = 207, 360, 388, 113, 355, 980, 323, 979
98
- print("the labels", class_labels)
99
 
100
  # generate the tokens and images
101
  with torch.cuda.amp.autocast():
@@ -113,7 +113,7 @@ class MARModel(DiffusionPipeline):
113
 
114
  # save the images
115
  image_path = os.path.join(output_dir, "sampled_image.png")
116
- samples_per_row = kwargs.get("samples_per_row", 6)
117
 
118
  save_image(
119
  sampled_images, image_path, nrow=int(samples_per_row), normalize=True, value_range=(-1, 1)
 
6
  from huggingface_hub import hf_hub_download
7
  from safetensors.torch import load_file
8
  import os
9
+ from safetensors.torch import load_file
 
10
  from .vae import AutoencoderKL
11
  from .mar import mar_base, mar_large, mar_huge
12
 
 
45
  if model_type == "mar_base":
46
  diffloss_d = 6
47
  diffloss_w = 1024
48
+ model_path = "mar-base.safetensors"
49
  elif model_type == "mar_large":
50
  diffloss_d = 8
51
  diffloss_w = 1280
52
+ model_path = "mar-large.safetensors"
53
  elif model_type == "mar_huge":
54
  diffloss_d = 12
55
  diffloss_w = 1536
56
+ model_path = "mar-huge.safetensors"
57
  else:
58
  raise NotImplementedError
59
  # download and load the model weights (.safetensors or .pth)
60
  model_checkpoint_path = hf_hub_download(
61
  repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
62
+ filename=kwargs.get("model_filename", model_path)
63
  )
 
64
 
65
  model_fn = model_mapping[model_type]
66
 
 
71
  num_sampling_steps=str(num_sampling_steps_diffloss)
72
  ).cuda()
73
 
74
+ # use safetensors
75
+ state_dict = load_file(safetensors_path)
76
  model.load_state_dict(state_dict)
77
  model.eval()
78
 
 
87
  vae = vae.to(device).eval()
88
 
89
  # set up user-specified or default values for generation
90
+ seed = kwargs.get("seed", 6)
91
  torch.manual_seed(seed)
92
  np.random.seed(seed)
93
 
 
95
  cfg_scale = kwargs.get("cfg_scale", 4)
96
  cfg_schedule = kwargs.get("cfg_schedule", "constant")
97
  temperature = kwargs.get("temperature", 1.0)
98
+ class_labels = kwargs.get("class_labels", 207, 360, 388, 113, 355, 980, 323, 979)
 
 
99
 
100
  # generate the tokens and images
101
  with torch.cuda.amp.autocast():
 
113
 
114
  # save the images
115
  image_path = os.path.join(output_dir, "sampled_image.png")
116
+ samples_per_row = kwargs.get("samples_per_row", 4)
117
 
118
  save_image(
119
  sampled_images, image_path, nrow=int(samples_per_row), normalize=True, value_range=(-1, 1)