Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -19,13 +19,14 @@ def get_model(model_name, model_path):
|
|
19 |
def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, max_length=200):
|
20 |
text += '\n'
|
21 |
input_ids = tokenizer.encode(text, return_tensors="pt")
|
|
|
22 |
with torch.no_grad():
|
23 |
out = model.generate(input_ids,
|
24 |
do_sample=True,
|
25 |
num_beams=n_beams,
|
26 |
temperature=temperature,
|
27 |
top_p=top_p,
|
28 |
-
max_length=max_length,
|
29 |
)
|
30 |
|
31 |
return list(map(tokenizer.decode, out))[0]
|
@@ -41,7 +42,7 @@ st.image(image, caption='NeuroKorzh')
|
|
41 |
|
42 |
st.markdown("\n")
|
43 |
|
44 |
-
text = st.text_input('Starting point for text generation', 'Что делать, Макс?'
|
45 |
button = st.button('Go')
|
46 |
|
47 |
if button:
|
|
|
19 |
def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, max_length=200):
|
20 |
text += '\n'
|
21 |
input_ids = tokenizer.encode(text, return_tensors="pt")
|
22 |
+
length_of_prompt = len(input_ids[0])
|
23 |
with torch.no_grad():
|
24 |
out = model.generate(input_ids,
|
25 |
do_sample=True,
|
26 |
num_beams=n_beams,
|
27 |
temperature=temperature,
|
28 |
top_p=top_p,
|
29 |
+
max_length=max_length + length_of_prompt,
|
30 |
)
|
31 |
|
32 |
return list(map(tokenizer.decode, out))[0]
|
|
|
42 |
|
43 |
st.markdown("\n")
|
44 |
|
45 |
+
text = st.text_input(label='Starting point for text generation', value='Что делать, Макс?')
|
46 |
button = st.button('Go')
|
47 |
|
48 |
if button:
|