irfantea committed on
Commit 987ad3f
Parent: e3af0d6

Update app.py

Files changed (1):
  app.py +6 -3
app.py CHANGED
@@ -1,3 +1,6 @@
+ import torch
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
  import streamlit as st

@@ -17,8 +20,8 @@ sentence = st.text_input('Input your sentence here:', value='My favorite ice cre

  st.info("Max generated sentence: 100 words")
  if (st.button("Generate")):
- input_ids = tokenizer.encode(sentence, return_tensors='pt')
- paragraph_generated = model.generate(input_ids, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True)
- text = tokenizer.decode(paragraph_generated[0], skip_special_tokens=True)
+ input_ids = tokenizer.encode(sentence, return_tensors='pt').to(device)
+ paragraph_generated = model.generate(input_ids, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True).to(device)
+ text = tokenizer.decode(paragraph_generated[0], skip_special_tokens=True).to(device)

  st.write(text)
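
Note: as committed, the last two `.to(device)` calls are redundant or broken. `generate()` already returns a tensor on the model's device, and `tokenizer.decode()` returns a plain `str`, which has no `.to()` method; the model itself is also never moved to the GPU. Below is a minimal sketch of how the device handling could be made runnable. The "gpt2" checkpoint, the model/tokenizer loading lines, the indentation inside the button branch, and placing `st.write(text)` inside that branch are all assumptions, since those parts of app.py are outside this diff (the text_input's default value is truncated in the hunk header and is omitted here).

# Sketch only: assumed layout of app.py with working device handling.
import torch
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer

device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")           # assumed checkpoint
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)  # model must sit on the same device as input_ids

# The committed app passes a default value here; it is truncated in the diff, so none is shown.
sentence = st.text_input('Input your sentence here:')

st.info("Max generated sentence: 100 words")
if st.button("Generate"):
    # Only the input tensor needs moving; generate() returns a tensor on the model's device.
    input_ids = tokenizer.encode(sentence, return_tensors='pt').to(device)
    paragraph_generated = model.generate(
        input_ids,
        max_length=100,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    # decode() returns a plain str, so it takes no .to(device).
    text = tokenizer.decode(paragraph_generated[0], skip_special_tokens=True)
    st.write(text)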