ybelkada HF staff commited on
Commit
63d75f6
1 Parent(s): da8e0b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -17,12 +17,19 @@ model = DistributedBloomForCausalLM.from_pretrained("bigscience/test-bloomd-6b3"
17
 
18
  def inference(text, seq_length=1):
19
  input_ids = tokenizer(text, return_tensors='pt')['input_ids']
 
20
  with torch.inference_mode(), model.transformer.h.inference_session() as remote_transformer:
21
  for i in range(seq_length):
22
  h = model.transformer.word_embeddings(input_ids)
23
  h = model.transformer.word_embeddings_layernorm(h)
24
  h = remote_transformer.step(h)
25
- return repr(h)
26
-
 
 
 
 
 
 
27
  iface = gr.Interface(fn=inference, inputs="text", outputs="text")
28
  iface.launch()
 
17
 
18
  def inference(text, seq_length=1):
19
  input_ids = tokenizer(text, return_tensors='pt')['input_ids']
20
+ final_tokens = input_ids
21
  with torch.inference_mode(), model.transformer.h.inference_session() as remote_transformer:
22
  for i in range(seq_length):
23
  h = model.transformer.word_embeddings(input_ids)
24
  h = model.transformer.word_embeddings_layernorm(h)
25
  h = remote_transformer.step(h)
26
+ h = model.transformer.ln_f(h)
27
+ h = F.linear(h, weight=model.transformer.word_embeddings.weight) # note: this line takes a while, will also be fixed
28
+ next_token_ix = torch.multinomial((h[0, -1] / 0.8).softmax(-1), 1)
29
+
30
+ final_tokens = torch.cat([final_tokens, next_token_ix], dim=-1)
31
+ input_ids = next_token_ix.view(1, 1)
32
+ return tokenizer.decode(final_tokens, skip_special_tokens=False)
33
+
34
  iface = gr.Interface(fn=inference, inputs="text", outputs="text")
35
  iface.launch()