vladyur commited on
Commit
b86439f
1 Parent(s): da9ee26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -19,13 +19,14 @@ def get_model(model_name, model_path):
19
  def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, max_length=200):
20
  text += '\n'
21
  input_ids = tokenizer.encode(text, return_tensors="pt")
 
22
  with torch.no_grad():
23
  out = model.generate(input_ids,
24
  do_sample=True,
25
  num_beams=n_beams,
26
  temperature=temperature,
27
  top_p=top_p,
28
- max_length=max_length,
29
  )
30
 
31
  return list(map(tokenizer.decode, out))[0]
@@ -41,7 +42,7 @@ st.image(image, caption='NeuroKorzh')
41
 
42
  st.markdown("\n")
43
 
44
- text = st.text_input('Starting point for text generation', 'Что делать, Макс?', height=100)
45
  button = st.button('Go')
46
 
47
  if button:
 
19
  def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, max_length=200):
20
  text += '\n'
21
  input_ids = tokenizer.encode(text, return_tensors="pt")
22
+ length_of_prompt = len(input_ids[0])
23
  with torch.no_grad():
24
  out = model.generate(input_ids,
25
  do_sample=True,
26
  num_beams=n_beams,
27
  temperature=temperature,
28
  top_p=top_p,
29
+ max_length=max_length + length_of_prompt,
30
  )
31
 
32
  return list(map(tokenizer.decode, out))[0]
 
42
 
43
  st.markdown("\n")
44
 
45
+ text = st.text_input(label='Starting point for text generation', value='Что делать, Макс?')
46
  button = st.button('Go')
47
 
48
  if button: