ironjr commited on
Commit
e533760
1 Parent(s): 927e766

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +18 -9
model.py CHANGED
@@ -1121,15 +1121,24 @@ class StreamMultiDiffusion(nn.Module):
1121
  else:
1122
  x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
1123
 
1124
- try:
1125
- model_pred = self.unet(
1126
- x_t_latent_plus_uc.to(self.unet.dtype), # (B, 4, h, w)
1127
- t_list, # (B,)
1128
- encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
1129
- return_dict=False,
1130
- )[0] # (B, 4, h, w)
1131
- except Exception as e:
1132
- print(e)
 
 
 
 
 
 
 
 
 
1133
  print('222222222222222', model_pred.dtype)
1134
 
1135
  if self.bootstrap_steps[0] > 0:
 
1121
  else:
1122
  x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
1123
 
1124
+ ns = []
1125
+ c1, c2, c3 = 0, 0, 0
1126
+ for n, p in self.unet.named_parameters():
1127
+ if p.data.dtype == torch.float:
1128
+ c1 += 1
1129
+ ns.append(n)
1130
+ elif p.data.dtype == torch.half:
1131
+ c2 += 1
1132
+ else:
1133
+ c3 += 1
1134
+ print(c1, c2, c3)
1135
+ print(ns)
1136
+ model_pred = self.unet(
1137
+ x_t_latent_plus_uc.to(self.unet.dtype), # (B, 4, h, w)
1138
+ t_list, # (B,)
1139
+ encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
1140
+ return_dict=False,
1141
+ )[0] # (B, 4, h, w)
1142
  print('222222222222222', model_pred.dtype)
1143
 
1144
  if self.bootstrap_steps[0] > 0: