wondervictor commited on
Commit
6b11bf0
·
verified ·
1 Parent(s): d9a7d69

Update autoregressive/models/generate.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/generate.py +2 -0
autoregressive/models/generate.py CHANGED
@@ -143,7 +143,9 @@ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_int
143
  print(model.adapter.model.embeddings.patch_embeddings.projection.weight)
144
  condition = model.adapter(condition)
145
  print(condition)
 
146
  condition = model.adapter_mlp(condition)
 
147
  if model.model_type == 'c2i':
148
  if cfg_scale > 1.0:
149
  cond_null = torch.ones_like(cond) * model.num_classes
 
143
  print(model.adapter.model.embeddings.patch_embeddings.projection.weight)
144
  condition = model.adapter(condition)
145
  print(condition)
146
+ condition = torch.ones_like(condition)
147
  condition = model.adapter_mlp(condition)
148
+ print(condition)
149
  if model.model_type == 'c2i':
150
  if cfg_scale > 1.0:
151
  cond_null = torch.ones_like(cond) * model.num_classes