wondervictor commited on
Commit
e4507cb
·
verified ·
1 Parent(s): c6ff29a

Update autoregressive/models/gpt_t2i.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/gpt_t2i.py +3 -3
autoregressive/models/gpt_t2i.py CHANGED
@@ -434,9 +434,9 @@ class Transformer(nn.Module):
434
  if condition is not None:
435
  condition_embeddings = self.condition_mlp(condition,train=self.training)#.to(torch.bfloat16),train=self.training)
436
  self.condition_token = condition_embeddings
437
- self.condition_token = [self.condition_layer[0](self.condition_token),
438
- self.condition_layer[1](self.condition_token),
439
- self.condition_layer[2](self.condition_token)]
440
 
441
  else: # decode_n_tokens(kv cache) in inference
442
  token_embeddings = self.tok_embeddings(idx)
 
434
  if condition is not None:
435
  condition_embeddings = self.condition_mlp(condition,train=self.training)#.to(torch.bfloat16),train=self.training)
436
  self.condition_token = condition_embeddings
437
+ self.condition_token = [self.condition_layers[0](self.condition_token),
438
+ self.condition_layers[1](self.condition_token),
439
+ self.condition_layers[2](self.condition_token)]
440
 
441
  else: # decode_n_tokens(kv cache) in inference
442
  token_embeddings = self.tok_embeddings(idx)