rynmurdock commited on
Commit
8d3c278
1 Parent(s): 8ed9e1d
Files changed (2) hide show
  1. app.py +3 -1
  2. requirements.txt +1 -0
app.py CHANGED
@@ -125,7 +125,7 @@ def text_from_latent_code(latent_z):
125
  model=model_vae.decoder,
126
  context=context_tokens,
127
  past=past,
128
- length= length, # Chunyuan: Fix length; or use <EOS> to complete a sentence
129
  temperature=.5,
130
  top_k=100,
131
  top_p=.98,
@@ -190,6 +190,8 @@ model_vae.load_state_dict(checkpoint['model_state_dict'])
190
  print("Pre-trained Optimus is successfully loaded")
191
  model_vae.to(DEVICE).to(torch.bfloat16)
192
  model_vae = torch.compile(model_vae)
 
 
193
 
194
  l = latent_code_from_text('A photo of a mountain.')[0]
195
  t = text_from_latent_code(l)
 
125
  model=model_vae.decoder,
126
  context=context_tokens,
127
  past=past,
128
+ length=length, # Chunyuan: Fix length; or use <EOS> to complete a sentence
129
  temperature=.5,
130
  top_k=100,
131
  top_p=.98,
 
190
  print("Pre-trained Optimus is successfully loaded")
191
  model_vae.to(DEVICE).to(torch.bfloat16)
192
  model_vae = torch.compile(model_vae)
193
+ model_vae.encoder = torch.compile(model_vae.encoder)
194
+ model_vae.decoder = torch.compile(model_vae.decoder)
195
 
196
  l = latent_code_from_text('A photo of a mountain.')[0]
197
  t = text_from_latent_code(l)
requirements.txt CHANGED
@@ -12,3 +12,4 @@ sentencepiece
12
  peft
13
  tensorflow_hub
14
  tensorflow==2.14.0
 
 
12
  peft
13
  tensorflow_hub
14
  tensorflow==2.14.0
15
+ sacremoses