daniel shalem commited on
Commit
6a9d9a1
·
1 Parent(s): 91602f9

Feature: Add full bfloat16 support.

Browse files
xora/examples/image_to_video.py CHANGED
@@ -142,6 +142,12 @@ def main():
142
  help="Mixed precision in float32 and bfloat16",
143
  )
144
 
 
 
 
 
 
 
145
  # Prompts
146
  parser.add_argument(
147
  "--prompt",
@@ -176,6 +182,9 @@ def main():
176
  "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
177
  )
178
 
 
 
 
179
  # Use submodels for the pipeline
180
  submodel_dict = {
181
  "transformer": unet,
 
142
  help="Mixed precision in float32 and bfloat16",
143
  )
144
 
145
+ parser.add_argument(
146
+ "--bfloat16",
147
+ action="store_true",
148
+ help="Denoise in bfloat16",
149
+ )
150
+
151
  # Prompts
152
  parser.add_argument(
153
  "--prompt",
 
182
  "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
183
  )
184
 
185
+ if args.bfloat16 and unet.dtype != torch.bfloat16:
186
+ unet = unet.to(torch.bfloat16)
187
+
188
  # Use submodels for the pipeline
189
  submodel_dict = {
190
  "transformer": unet,
xora/examples/text_to_video.py CHANGED
@@ -49,6 +49,16 @@ def main():
49
  required=True,
50
  help="Path to the directory containing unet, vae, and scheduler subdirectories",
51
  )
 
 
 
 
 
 
 
 
 
 
52
  args = parser.parse_args()
53
 
54
  # Paths for the separate mode directories
@@ -72,6 +82,9 @@ def main():
72
  "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
73
  )
74
 
 
 
 
75
  # Use submodels for the pipeline
76
  submodel_dict = {
77
  "transformer": unet, # using unet for transformer
@@ -115,6 +128,7 @@ def main():
115
  **sample,
116
  is_video=True,
117
  vae_per_channel_normalize=True,
 
118
  ).images
119
 
120
  print("Generated images (video frames).")
 
49
  required=True,
50
  help="Path to the directory containing unet, vae, and scheduler subdirectories",
51
  )
52
+ parser.add_argument(
53
+ "--mixed_precision",
54
+ action="store_true",
55
+ help="Mixed precision in float32 and bfloat16",
56
+ )
57
+ parser.add_argument(
58
+ "--bfloat16",
59
+ action="store_true",
60
+ help="Denoise in bfloat16",
61
+ )
62
  args = parser.parse_args()
63
 
64
  # Paths for the separate mode directories
 
82
  "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
83
  )
84
 
85
+ if args.bfloat16 and unet.dtype != torch.bfloat16:
86
+ unet = unet.to(torch.bfloat16)
87
+
88
  # Use submodels for the pipeline
89
  submodel_dict = {
90
  "transformer": unet, # using unet for transformer
 
128
  **sample,
129
  is_video=True,
130
  vae_per_channel_normalize=True,
131
+ mixed_precision=args.mixed_precision,
132
  ).images
133
 
134
  print("Generated images (video frames).")
xora/models/transformers/transformer3d.py CHANGED
@@ -253,7 +253,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
253
  return fractional_positions
254
 
255
  def precompute_freqs_cis(self, indices_grid, spacing="exp"):
256
- dtype = self.dtype
257
  dim = self.inner_dim
258
  theta = self.positional_embedding_theta
259
 
@@ -305,7 +305,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
305
  sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
306
  cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
307
  sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
308
- return cos_freq.to(dtype), sin_freq.to(dtype)
309
 
310
  def forward(
311
  self,
 
253
  return fractional_positions
254
 
255
  def precompute_freqs_cis(self, indices_grid, spacing="exp"):
256
+ dtype = torch.float32 # We need full precision in the freqs_cis computation.
257
  dim = self.inner_dim
258
  theta = self.positional_embedding_theta
259
 
 
305
  sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
306
  cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
307
  sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
308
+ return cos_freq.to(self.dtype), sin_freq.to(self.dtype)
309
 
310
  def forward(
311
  self,